Skip to content

[mlir][EmitC] Update WrapFunInClassPass pass#203641

Open
beamandala wants to merge 1 commit into
llvm:mainfrom
beamandala:bhavesh/fix-wrap-emitc-func-in-class-pass
Open

[mlir][EmitC] Update WrapFunInClassPass pass#203641
beamandala wants to merge 1 commit into
llvm:mainfrom
beamandala:bhavesh/fix-wrap-emitc-func-in-class-pass

Conversation

@beamandala

Copy link
Copy Markdown

Update the WrapFuncInClassPass pass so that GlobalOps are moved into the ClassOp as FieldOps. This respects MLIR's behavior of resolving references to the closest parent operation that defines a symbol table which is the ClassOp that we are creating in this pass.

Without this change, references to a GlobalOp in GetGlobalOp are failing to resolve.

Details:

  • Identify GlobalOps
  • Create a FieldOp within the ClassOp for each GlobalOp
  • Delete the GlobalOps after all functions have been wrapped in a class. Doing this after every function can cause an error when multiple functions refer to the same GlobalOp(s) which would be deleted after the first function is wrapped in a class.

Also renamed fName parameter in populateWrapFuncInClass to funcName to match naming in WrapFuncInClass.

Update the `WrapFuncInClassPass` pass so that `GlobalOp`s are moved into
the `ClassOp` as `FieldOps`. This respects MLIR's behavior of resolving
references to the closest parent operation that defines a symbol table
which is the `ClassOp` that we are creating in this pass.

Details:
- Identify `GlobalOp`s
- Create a `FieldOp` within the `ClassOp` for each `GlobalOp`
- Delete the `GlobalOp`s after all functions have been wrapped in a
  class. Doing this after every function can cause an error when
multiple functions refer to the same `GlobalOp`(s) which would be
deleted after the first function is wrapped in a class.

Also renamed `fName` parameter in `populateWrapFuncInClass` to
`funcName` to match naming in `WrapFuncInClass`.
@github-actions

Copy link
Copy Markdown

Hello @beamandala 👋

Thank you for submitting a Pull Request (PR) to the LLVM Project. Since this is your first PR, here are a few useful links covering our main contribution policies and review practices.

  • All contributions to LLVM must follow our LLVM AI Tool Use Policy. In particular, if you used AI while working on this PR, remember to add a note to the PR description.
  • The LLVM Code-Review Policy and Practices document contains practical information about the PR process, including how patches are reviewed and accepted, and who can review a PR.
  • Our LLVM Developer Policy describes our expectations for code quality, commit summaries and contains notes on our CI system.

Please reply to this message to confirm that you have read these policies, especially the LLVM AI Tool Use Policy, and that any AI tool usage has been noted in the PR description.


Frequently asked questions

How do I add reviewers?

This PR will be automatically labeled, and the relevant teams will be notified. For some parts of the project, reviewers may also be added automatically.

You can also add reviewers manually using the Reviewers section on this page. If you cannot use that section, it is probably because you do not have write permissions for the repository. In that case, you can request a review by tagging reviewers in a comment using @ followed by their GitHub username.

What if there are no comments?

If you have not received any comments on your PR after a week, you can request a review by pinging the PR with a comment such as “Ping”. The common courtesy ping rate is once a week. Please remember that you are asking for volunteer time from other developers.

Are any special GitHub settings required to contribute to LLVM?

We only require contributors to have a public email address associated with their GitHub commits, see this section of LLVM Developer Policy for details.


If you have questions, feel free to leave a comment on this PR, or ask on LLVM Discord or LLVM Discourse.

Thank you,
The LLVM Community

@beamandala

Copy link
Copy Markdown
Author

@llvmorg-github-actions

Copy link
Copy Markdown

@llvm/pr-subscribers-mlir-emitc

Author: Bhavesh M (beamandala)

Changes

Update the WrapFuncInClassPass pass so that GlobalOps are moved into the ClassOp as FieldOps. This respects MLIR's behavior of resolving references to the closest parent operation that defines a symbol table which is the ClassOp that we are creating in this pass.

Without this change, references to a GlobalOp in GetGlobalOp are failing to resolve.

Details:

  • Identify GlobalOps
  • Create a FieldOp within the ClassOp for each GlobalOp
  • Delete the GlobalOps after all functions have been wrapped in a class. Doing this after every function can cause an error when multiple functions refer to the same GlobalOp(s) which would be deleted after the first function is wrapped in a class.

Also renamed fName parameter in populateWrapFuncInClass to funcName to match naming in WrapFuncInClass.


Full diff: https://github.com/llvm/llvm-project/pull/203641.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h (+2-1)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp (+35-8)
  • (modified) mlir/test/Dialect/EmitC/wrap-func-in-class.mlir (+22)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 40ecef33448d7..c34c3303a6ab3 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,7 +20,7 @@ def FormExpressionsPass : Pass<"form-expressions"> {
   let dependentDialects = ["emitc::EmitCDialect"];
 }
 
-def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class", "ModuleOp"> {
   let summary = "Wrap functions in classes, using arguments as fields.";
   let description = [{
     This pass transforms `emitc.func` operations into `emitc.class` operations.
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 962bdb3c032bf..117a1ef9e2e61 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -32,7 +32,8 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
 // The WrapFuncInClass pass.
 //===----------------------------------------------------------------------===//
 
-void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef fName);
+void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef funcName,
+                             llvm::SmallVector<emitc::GlobalOp> &globalsToMove);
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fc8acd616ba70..7f1b9ea95212d 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -28,12 +28,21 @@ struct WrapFuncInClassPass
     : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
   using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
   void runOnOperation() override {
-    Operation *rootOp = getOperation();
+    mlir::ModuleOp moduleOp = getOperation();
+
+    llvm::SmallVector<emitc::GlobalOp> globalsToMove;
+    moduleOp.walk(
+        [&](mlir::emitc::GlobalOp op) { globalsToMove.push_back(op); });
 
     RewritePatternSet patterns(&getContext());
-    populateWrapFuncInClass(patterns, funcName);
+    populateWrapFuncInClass(patterns, funcName, globalsToMove);
+
+    walkAndApplyPatterns(moduleOp, std::move(patterns));
 
-    walkAndApplyPatterns(rootOp, std::move(patterns));
+    for (GlobalOp globalOp : globalsToMove) {
+      if (globalOp)
+        globalOp.erase();
+    }
   }
 };
 
@@ -43,8 +52,10 @@ struct WrapFuncInClassPass
 
 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
 public:
-  WrapFuncInClass(MLIRContext *context, StringRef funcName)
-      : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName) {}
+  WrapFuncInClass(MLIRContext *context, StringRef funcName,
+                  llvm::SmallVector<emitc::GlobalOp> &globalsToMove)
+      : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName),
+        globalsToMove(globalsToMove) {}
 
   LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
@@ -72,6 +83,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       }
     }
 
+    for (GlobalOp globalOp : globalsToMove) {
+      emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+                             globalOp.getSymNameAttr(), globalOp.getTypeAttr(),
+                             globalOp.getInitialValueAttr());
+    }
+
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
     FunctionType funcType = funcOp.getFunctionType();
     Location loc = funcOp.getLoc();
@@ -99,6 +116,14 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     if (failed(newFuncOp.eraseArguments(argsToErase)))
       newFuncOp->emitOpError("failed to erase all arguments using BitVector");
 
+    newFuncOp.walk([&](emitc::GetGlobalOp getGlobalOp) {
+      rewriter.setInsertionPoint(getGlobalOp);
+      emitc::GetFieldOp getFieldOp = emitc::GetFieldOp::create(
+          rewriter, getGlobalOp.getLoc(), getGlobalOp.getType(),
+          getGlobalOp.getNameAttr());
+      rewriter.replaceOp(getGlobalOp, getFieldOp);
+    });
+
     rewriter.replaceOp(funcOp, newClassOp);
     return success();
   }
@@ -107,9 +132,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
   /// Name of the newly generated member function with body matching the input
   /// function.
   std::string funcName;
+  llvm::SmallVector<emitc::GlobalOp> globalsToMove;
 };
 
-void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
-                                          StringRef funcName) {
-  patterns.add<WrapFuncInClass>(patterns.getContext(), funcName);
+void mlir::emitc::populateWrapFuncInClass(
+    RewritePatternSet &patterns, StringRef funcName,
+    llvm::SmallVector<emitc::GlobalOp> &globalsToMove) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), funcName, globalsToMove);
 }
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
index cb5f99d31e9da..bdf13ce4df8a4 100644
--- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
+// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
 
 emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
   emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
@@ -58,3 +59,24 @@ module attributes { } {
 
 // EXECUTE-NOT: operator
 // EXECUTE: execute()
+
+// -----
+
+module attributes { } {
+  emitc.global static const @global_arr : !emitc.array<1xi8> = dense<0>
+  emitc.func @foo() {
+    %0 = emitc.get_global @global_arr : !emitc.array<1xi8>
+    emitc.return
+  }
+}
+
+// CHECK:   emitc.class @fooClass {
+// CHECK:     emitc.field @global_arr : !emitc.array<1xi8> = dense<0>
+// CHECK:     emitc.func @"operator()"() {
+// CHECK:       %0 = get_field @global_arr : !emitc.array<1xi8>
+// CHECK:       return
+// CHECK:     }
+// CHECK:   }
+
+// EXECUTE-NOT: operator
+// EXECUTE: execute()

@llvmorg-github-actions

Copy link
Copy Markdown

@llvm/pr-subscribers-mlir

Author: Bhavesh M (beamandala)

Changes

Update the WrapFuncInClassPass pass so that GlobalOps are moved into the ClassOp as FieldOps. This respects MLIR's behavior of resolving references to the closest parent operation that defines a symbol table which is the ClassOp that we are creating in this pass.

Without this change, references to a GlobalOp in GetGlobalOp are failing to resolve.

Details:

  • Identify GlobalOps
  • Create a FieldOp within the ClassOp for each GlobalOp
  • Delete the GlobalOps after all functions have been wrapped in a class. Doing this after every function can cause an error when multiple functions refer to the same GlobalOp(s) which would be deleted after the first function is wrapped in a class.

Also renamed fName parameter in populateWrapFuncInClass to funcName to match naming in WrapFuncInClass.


Full diff: https://github.com/llvm/llvm-project/pull/203641.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h (+2-1)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp (+35-8)
  • (modified) mlir/test/Dialect/EmitC/wrap-func-in-class.mlir (+22)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 40ecef33448d7..c34c3303a6ab3 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,7 +20,7 @@ def FormExpressionsPass : Pass<"form-expressions"> {
   let dependentDialects = ["emitc::EmitCDialect"];
 }
 
-def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class", "ModuleOp"> {
   let summary = "Wrap functions in classes, using arguments as fields.";
   let description = [{
     This pass transforms `emitc.func` operations into `emitc.class` operations.
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 962bdb3c032bf..117a1ef9e2e61 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -32,7 +32,8 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
 // The WrapFuncInClass pass.
 //===----------------------------------------------------------------------===//
 
-void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef fName);
+void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef funcName,
+                             llvm::SmallVector<emitc::GlobalOp> &globalsToMove);
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fc8acd616ba70..7f1b9ea95212d 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -28,12 +28,21 @@ struct WrapFuncInClassPass
     : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
   using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
   void runOnOperation() override {
-    Operation *rootOp = getOperation();
+    mlir::ModuleOp moduleOp = getOperation();
+
+    llvm::SmallVector<emitc::GlobalOp> globalsToMove;
+    moduleOp.walk(
+        [&](mlir::emitc::GlobalOp op) { globalsToMove.push_back(op); });
 
     RewritePatternSet patterns(&getContext());
-    populateWrapFuncInClass(patterns, funcName);
+    populateWrapFuncInClass(patterns, funcName, globalsToMove);
+
+    walkAndApplyPatterns(moduleOp, std::move(patterns));
 
-    walkAndApplyPatterns(rootOp, std::move(patterns));
+    for (GlobalOp globalOp : globalsToMove) {
+      if (globalOp)
+        globalOp.erase();
+    }
   }
 };
 
@@ -43,8 +52,10 @@ struct WrapFuncInClassPass
 
 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
 public:
-  WrapFuncInClass(MLIRContext *context, StringRef funcName)
-      : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName) {}
+  WrapFuncInClass(MLIRContext *context, StringRef funcName,
+                  llvm::SmallVector<emitc::GlobalOp> &globalsToMove)
+      : OpRewritePattern<emitc::FuncOp>(context), funcName(funcName),
+        globalsToMove(globalsToMove) {}
 
   LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
@@ -72,6 +83,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       }
     }
 
+    for (GlobalOp globalOp : globalsToMove) {
+      emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+                             globalOp.getSymNameAttr(), globalOp.getTypeAttr(),
+                             globalOp.getInitialValueAttr());
+    }
+
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
     FunctionType funcType = funcOp.getFunctionType();
     Location loc = funcOp.getLoc();
@@ -99,6 +116,14 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     if (failed(newFuncOp.eraseArguments(argsToErase)))
       newFuncOp->emitOpError("failed to erase all arguments using BitVector");
 
+    newFuncOp.walk([&](emitc::GetGlobalOp getGlobalOp) {
+      rewriter.setInsertionPoint(getGlobalOp);
+      emitc::GetFieldOp getFieldOp = emitc::GetFieldOp::create(
+          rewriter, getGlobalOp.getLoc(), getGlobalOp.getType(),
+          getGlobalOp.getNameAttr());
+      rewriter.replaceOp(getGlobalOp, getFieldOp);
+    });
+
     rewriter.replaceOp(funcOp, newClassOp);
     return success();
   }
@@ -107,9 +132,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
   /// Name of the newly generated member function with body matching the input
   /// function.
   std::string funcName;
+  llvm::SmallVector<emitc::GlobalOp> globalsToMove;
 };
 
-void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
-                                          StringRef funcName) {
-  patterns.add<WrapFuncInClass>(patterns.getContext(), funcName);
+void mlir::emitc::populateWrapFuncInClass(
+    RewritePatternSet &patterns, StringRef funcName,
+    llvm::SmallVector<emitc::GlobalOp> &globalsToMove) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), funcName, globalsToMove);
 }
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
index cb5f99d31e9da..bdf13ce4df8a4 100644
--- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
+// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
 
 emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
   emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
@@ -58,3 +59,24 @@ module attributes { } {
 
 // EXECUTE-NOT: operator
 // EXECUTE: execute()
+
+// -----
+
+module attributes { } {
+  emitc.global static const @global_arr : !emitc.array<1xi8> = dense<0>
+  emitc.func @foo() {
+    %0 = emitc.get_global @global_arr : !emitc.array<1xi8>
+    emitc.return
+  }
+}
+
+// CHECK:   emitc.class @fooClass {
+// CHECK:     emitc.field @global_arr : !emitc.array<1xi8> = dense<0>
+// CHECK:     emitc.func @"operator()"() {
+// CHECK:       %0 = get_field @global_arr : !emitc.array<1xi8>
+// CHECK:       return
+// CHECK:     }
+// CHECK:   }
+
+// EXECUTE-NOT: operator
+// EXECUTE: execute()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant