[mlir][EmitC] Update WrapFunInClassPass pass#203641
Conversation
Update the `WrapFuncInClassPass` pass so that `GlobalOp`s are moved into the `ClassOp` as `FieldOps`. This respects MLIR's behavior of resolving references to the closest parent operation that defines a symbol table which is the `ClassOp` that we are creating in this pass. Details: - Identify `GlobalOp`s - Create a `FieldOp` within the `ClassOp` for each `GlobalOp` - Delete the `GlobalOp`s after all functions have been wrapped in a class. Doing this after every function can cause an error when multiple functions refer to the same `GlobalOp`(s) which would be deleted after the first function is wrapped in a class. Also renamed `fName` parameter in `populateWrapFuncInClass` to `funcName` to match naming in `WrapFuncInClass`.
|
Hello @beamandala 👋 Thank you for submitting a Pull Request (PR) to the LLVM Project. Since this is your first PR, here are a few useful links covering our main contribution policies and review practices.
Please reply to this message to confirm that you have read these policies, especially the LLVM AI Tool Use Policy, and that any AI tool usage has been noted in the PR description. Frequently asked questionsHow do I add reviewers? This PR will be automatically labeled, and the relevant teams will be notified. For some parts of the project, reviewers may also be added automatically. You can also add reviewers manually using the Reviewers section on this page. If you cannot use that section, it is probably because you do not have write permissions for the repository. In that case, you can request a review by tagging reviewers in a comment using What if there are no comments? If you have not received any comments on your PR after a week, you can request a review by pinging the PR with a comment such as “Ping”. The common courtesy ping rate is once a week. Please remember that you are asking for volunteer time from other developers. Are any special GitHub settings required to contribute to LLVM? We only require contributors to have a public email address associated with their GitHub commits, see this section of LLVM Developer Policy for details. If you have questions, feel free to leave a comment on this PR, or ask on LLVM Discord or LLVM Discourse. Thank you, |
|
@llvm/pr-subscribers-mlir-emitc Author: Bhavesh M (beamandala) ChangesUpdate the Without this change, references to a Details:
Also renamed Full diff: https://github.com/llvm/llvm-project/pull/203641.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 40ecef33448d7..c34c3303a6ab3 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,7 +20,7 @@ def FormExpressionsPass : Pass<"form-expressions"> {
let dependentDialects = ["emitc::EmitCDialect"];
}
-def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class", "ModuleOp"> {
let summary = "Wrap functions in classes, using arguments as fields.";
let description = [{
This pass transforms `emitc.func` operations into `emitc.class` operations.
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 962bdb3c032bf..117a1ef9e2e61 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -32,7 +32,8 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
// The WrapFuncInClass pass.
//===----------------------------------------------------------------------===//
-void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef fName);
+void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef funcName,
+ llvm::SmallVector<emitc::GlobalOp> &globalsToMove);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fc8acd616ba70..7f1b9ea95212d 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -28,12 +28,21 @@ struct WrapFuncInClassPass
: public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
void runOnOperation() override {
- Operation *rootOp = getOperation();
+ mlir::ModuleOp moduleOp = getOperation();
+
+ llvm::SmallVector<emitc::GlobalOp> globalsToMove;
+ moduleOp.walk(
+ [&](mlir::emitc::GlobalOp op) { globalsToMove.push_back(op); });
RewritePatternSet patterns(&getContext());
- populateWrapFuncInClass(patterns, funcName);
+ populateWrapFuncInClass(patterns, funcName, globalsToMove);
+
+ walkAndApplyPatterns(moduleOp, std::move(patterns));
- walkAndApplyPatterns(rootOp, std::move(patterns));
+ for (GlobalOp globalOp : globalsToMove) {
+ if (globalOp)
+ globalOp.erase();
+ }
}
};
@@ -43,8 +52,10 @@ struct WrapFuncInClassPass
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
public:
- WrapFuncInClass(MLIRContext *context, StringRef funcName)
- : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName) {}
+ WrapFuncInClass(MLIRContext *context, StringRef funcName,
+ llvm::SmallVector<emitc::GlobalOp> &globalsToMove)
+ : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName),
+ globalsToMove(globalsToMove) {}
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
PatternRewriter &rewriter) const override {
@@ -72,6 +83,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
}
}
+ for (GlobalOp globalOp : globalsToMove) {
+ emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+ globalOp.getSymNameAttr(), globalOp.getTypeAttr(),
+ globalOp.getInitialValueAttr());
+ }
+
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
FunctionType funcType = funcOp.getFunctionType();
Location loc = funcOp.getLoc();
@@ -99,6 +116,14 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
if (failed(newFuncOp.eraseArguments(argsToErase)))
newFuncOp->emitOpError("failed to erase all arguments using BitVector");
+ newFuncOp.walk([&](emitc::GetGlobalOp getGlobalOp) {
+ rewriter.setInsertionPoint(getGlobalOp);
+ emitc::GetFieldOp getFieldOp = emitc::GetFieldOp::create(
+ rewriter, getGlobalOp.getLoc(), getGlobalOp.getType(),
+ getGlobalOp.getNameAttr());
+ rewriter.replaceOp(getGlobalOp, getFieldOp);
+ });
+
rewriter.replaceOp(funcOp, newClassOp);
return success();
}
@@ -107,9 +132,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
/// Name of the newly generated member function with body matching the input
/// function.
std::string funcName;
+ llvm::SmallVector<emitc::GlobalOp> globalsToMove;
};
-void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
- StringRef funcName) {
- patterns.add<WrapFuncInClass>(patterns.getContext(), funcName);
+void mlir::emitc::populateWrapFuncInClass(
+ RewritePatternSet &patterns, StringRef funcName,
+ llvm::SmallVector<emitc::GlobalOp> &globalsToMove) {
+ patterns.add<WrapFuncInClass>(patterns.getContext(), funcName, globalsToMove);
}
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
index cb5f99d31e9da..bdf13ce4df8a4 100644
--- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
// RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
+// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
@@ -58,3 +59,24 @@ module attributes { } {
// EXECUTE-NOT: operator
// EXECUTE: execute()
+
+// -----
+
+module attributes { } {
+ emitc.global static const @global_arr : !emitc.array<1xi8> = dense<0>
+ emitc.func @foo() {
+ %0 = emitc.get_global @global_arr : !emitc.array<1xi8>
+ emitc.return
+ }
+}
+
+// CHECK: emitc.class @fooClass {
+// CHECK: emitc.field @global_arr : !emitc.array<1xi8> = dense<0>
+// CHECK: emitc.func @"operator()"() {
+// CHECK: %0 = get_field @global_arr : !emitc.array<1xi8>
+// CHECK: return
+// CHECK: }
+// CHECK: }
+
+// EXECUTE-NOT: operator
+// EXECUTE: execute()
|
|
@llvm/pr-subscribers-mlir Author: Bhavesh M (beamandala) ChangesUpdate the Without this change, references to a Details:
Also renamed Full diff: https://github.com/llvm/llvm-project/pull/203641.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 40ecef33448d7..c34c3303a6ab3 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,7 +20,7 @@ def FormExpressionsPass : Pass<"form-expressions"> {
let dependentDialects = ["emitc::EmitCDialect"];
}
-def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class", "ModuleOp"> {
let summary = "Wrap functions in classes, using arguments as fields.";
let description = [{
This pass transforms `emitc.func` operations into `emitc.class` operations.
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 962bdb3c032bf..117a1ef9e2e61 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -32,7 +32,8 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
// The WrapFuncInClass pass.
//===----------------------------------------------------------------------===//
-void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef fName);
+void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef funcName,
+ llvm::SmallVector<emitc::GlobalOp> &globalsToMove);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fc8acd616ba70..7f1b9ea95212d 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -28,12 +28,21 @@ struct WrapFuncInClassPass
: public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
void runOnOperation() override {
- Operation *rootOp = getOperation();
+ mlir::ModuleOp moduleOp = getOperation();
+
+ llvm::SmallVector<emitc::GlobalOp> globalsToMove;
+ moduleOp.walk(
+ [&](mlir::emitc::GlobalOp op) { globalsToMove.push_back(op); });
RewritePatternSet patterns(&getContext());
- populateWrapFuncInClass(patterns, funcName);
+ populateWrapFuncInClass(patterns, funcName, globalsToMove);
+
+ walkAndApplyPatterns(moduleOp, std::move(patterns));
- walkAndApplyPatterns(rootOp, std::move(patterns));
+ for (GlobalOp globalOp : globalsToMove) {
+ if (globalOp)
+ globalOp.erase();
+ }
}
};
@@ -43,8 +52,10 @@ struct WrapFuncInClassPass
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
public:
- WrapFuncInClass(MLIRContext *context, StringRef funcName)
- : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName) {}
+ WrapFuncInClass(MLIRContext *context, StringRef funcName,
+ llvm::SmallVector<emitc::GlobalOp> &globalsToMove)
+ : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName),
+ globalsToMove(globalsToMove) {}
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
PatternRewriter &rewriter) const override {
@@ -72,6 +83,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
}
}
+ for (GlobalOp globalOp : globalsToMove) {
+ emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+ globalOp.getSymNameAttr(), globalOp.getTypeAttr(),
+ globalOp.getInitialValueAttr());
+ }
+
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
FunctionType funcType = funcOp.getFunctionType();
Location loc = funcOp.getLoc();
@@ -99,6 +116,14 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
if (failed(newFuncOp.eraseArguments(argsToErase)))
newFuncOp->emitOpError("failed to erase all arguments using BitVector");
+ newFuncOp.walk([&](emitc::GetGlobalOp getGlobalOp) {
+ rewriter.setInsertionPoint(getGlobalOp);
+ emitc::GetFieldOp getFieldOp = emitc::GetFieldOp::create(
+ rewriter, getGlobalOp.getLoc(), getGlobalOp.getType(),
+ getGlobalOp.getNameAttr());
+ rewriter.replaceOp(getGlobalOp, getFieldOp);
+ });
+
rewriter.replaceOp(funcOp, newClassOp);
return success();
}
@@ -107,9 +132,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
/// Name of the newly generated member function with body matching the input
/// function.
std::string funcName;
+ llvm::SmallVector<emitc::GlobalOp> globalsToMove;
};
-void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
- StringRef funcName) {
- patterns.add<WrapFuncInClass>(patterns.getContext(), funcName);
+void mlir::emitc::populateWrapFuncInClass(
+ RewritePatternSet &patterns, StringRef funcName,
+ llvm::SmallVector<emitc::GlobalOp> &globalsToMove) {
+ patterns.add<WrapFuncInClass>(patterns.getContext(), funcName, globalsToMove);
}
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
index cb5f99d31e9da..bdf13ce4df8a4 100644
--- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
// RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
+// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
@@ -58,3 +59,24 @@ module attributes { } {
// EXECUTE-NOT: operator
// EXECUTE: execute()
+
+// -----
+
+module attributes { } {
+ emitc.global static const @global_arr : !emitc.array<1xi8> = dense<0>
+ emitc.func @foo() {
+ %0 = emitc.get_global @global_arr : !emitc.array<1xi8>
+ emitc.return
+ }
+}
+
+// CHECK: emitc.class @fooClass {
+// CHECK: emitc.field @global_arr : !emitc.array<1xi8> = dense<0>
+// CHECK: emitc.func @"operator()"() {
+// CHECK: %0 = get_field @global_arr : !emitc.array<1xi8>
+// CHECK: return
+// CHECK: }
+// CHECK: }
+
+// EXECUTE-NOT: operator
+// EXECUTE: execute()
|
Update the
WrapFuncInClassPasspass so thatGlobalOps are moved into theClassOpasFieldOps. This respects MLIR's behavior of resolving references to the closest parent operation that defines a symbol table which is theClassOpthat we are creating in this pass.Without this change, references to a
GlobalOpinGetGlobalOpare failing to resolve.Details:
GlobalOpsFieldOpwithin theClassOpfor eachGlobalOpGlobalOps after all functions have been wrapped in a class. Doing this after every function can cause an error when multiple functions refer to the sameGlobalOp(s) which would be deleted after the first function is wrapped in a class.Also renamed
fNameparameter inpopulateWrapFuncInClasstofuncNameto match naming inWrapFuncInClass.