From 9db62256d1105705ee2665c9e30f7331cdc0af28 Mon Sep 17 00:00:00 2001 From: duongynhi000005-oss Date: Thu, 11 Jun 2026 14:21:29 +0000 Subject: [PATCH] fix: [Bounty $50] Easy message passing interface Automated fix by AI agent (codex). --- src/main/java/com/aparapi/Kernel.java | 50 +++++++++++ .../aparapi/internal/writer/KernelWriter.java | 86 +++++++++++++++++++ .../codegen/test/LocalMessagePassing.java | 32 +++++++ .../codegen/test/LocalMessagePassingTest.java | 79 +++++++++++++++++ .../runtime/LocalMessagePassingTest.java | 52 +++++++++++ 5 files changed, 299 insertions(+) create mode 100644 src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java create mode 100644 src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java create mode 100644 src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java diff --git a/src/main/java/com/aparapi/Kernel.java b/src/main/java/com/aparapi/Kernel.java index 4b9686db..a2f7f62f 100644 --- a/src/main/java/com/aparapi/Kernel.java +++ b/src/main/java/com/aparapi/Kernel.java @@ -2452,6 +2452,56 @@ protected final int atomicXor(AtomicInteger p, int val) { return p.getAndAccumulate(val, xorOperator); } + /** + * Exchanges an integer value with another work-item in the same work-group using a caller supplied local-memory + * buffer. + *
+ * Every work-item in the work-group must call this method, and _mailbox must be a {@link Local} + * buffer, or a buffer named with {@link #LOCAL_SUFFIX}, with at least getLocalSize() entries. Each + * work-item publishes _value at its local id, waits for the work-group, reads the value published by + * _targetLocalId, and waits again before returning. As with {@link #localBarrier()}, all work-items + * must reach the same calls in the same order. + * + * @param _mailbox local-memory buffer used as a mailbox between work-items + * @param _targetLocalId local id whose value should be read by the current work-item + * @param _value value published by the current work-item + * @return value published by _targetLocalId + */ + @OpenCLMapping + @Experimental + protected final int localExchange(int[] _mailbox, int _targetLocalId, int _value) { + _mailbox[getLocalId()] = _value; + localBarrier(); + final int value = _mailbox[_targetLocalId]; + localBarrier(); + return value; + } + + /** + * Broadcasts an integer value from one work-item to the rest of the current work-group. + *
+ * Every work-item in the work-group must call this method, and _mailbox must be a {@link Local} + * buffer, or a buffer named with {@link #LOCAL_SUFFIX}, with at least one entry. The work-item whose local id equals + * _sourceLocalId writes _value, then all work-items read and return that value. As with + * {@link #localBarrier()}, all work-items must reach the same calls in the same order. + * + * @param _mailbox local-memory buffer used to publish the broadcast value + * @param _sourceLocalId local id of the work-item providing the broadcast value + * @param _value value published by the source work-item + * @return value published by _sourceLocalId + */ + @OpenCLMapping + @Experimental + protected final int localBroadcast(int[] _mailbox, int _sourceLocalId, int _value) { + if (getLocalId() == _sourceLocalId) { + _mailbox[0] = _value; + } + localBarrier(); + final int value = _mailbox[0]; + localBarrier(); + return value; + } + /** * Wait for all kernels in the current work group to rendezvous at this call before continuing execution.
* It will also enforce memory ordering, such that modifications made by each thread in the work-group, to the memory, diff --git a/src/main/java/com/aparapi/internal/writer/KernelWriter.java b/src/main/java/com/aparapi/internal/writer/KernelWriter.java index 0a72a33f..6073f837 100644 --- a/src/main/java/com/aparapi/internal/writer/KernelWriter.java +++ b/src/main/java/com/aparapi/internal/writer/KernelWriter.java @@ -318,6 +318,80 @@ public void writePragma(String _name, boolean _enable) { newLine(); } + private boolean usesLocalMessagePassing(Entrypoint _entryPoint, String _methodName) { + if (usesLocalMessagePassing(_entryPoint.getMethodModel(), _methodName)) { + return true; + } + for (final MethodModel methodModel : _entryPoint.getCalledMethods()) { + if (usesLocalMessagePassing(methodModel, _methodName)) { + return true; + } + } + return false; + } + + private boolean usesLocalMessagePassing(MethodModel _methodModel, String _methodName) { + for (final MethodCall methodCall : _methodModel.getMethodCalls()) { + final MethodEntry methodEntry = methodCall.getConstantPoolMethodEntry(); + final String methodName = methodEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8(); + final String methodSignature = methodEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8(); + if (Kernel.isMappedMethod(methodEntry) && _methodName.equals(methodName) && "([III)I".equals(methodSignature)) { + return true; + } + } + return false; + } + + private void writeLocalExchangeHelper() { + write("int localExchange(__local int *_mailbox, int _targetLocalId, int _value){"); + in(); + { + newLine(); + write("_mailbox[get_local_id(0)] = _value;"); + newLine(); + write("barrier(CLK_LOCAL_MEM_FENCE);"); + newLine(); + write("int value = _mailbox[_targetLocalId];"); + newLine(); + write("barrier(CLK_LOCAL_MEM_FENCE);"); + newLine(); + write("return value;"); + out(); + newLine(); + } + write("}"); + newLine(); + } + + private void writeLocalBroadcastHelper() { + write("int localBroadcast(__local int *_mailbox, int _sourceLocalId, int _value){"); + in(); + { + newLine(); + write("if (get_local_id(0)==_sourceLocalId){"); + in(); + { + newLine(); + write("_mailbox[0] = _value;"); + out(); + newLine(); + } + write("}"); + newLine(); + write("barrier(CLK_LOCAL_MEM_FENCE);"); + newLine(); + write("int value = _mailbox[0];"); + newLine(); + write("barrier(CLK_LOCAL_MEM_FENCE);"); + newLine(); + write("return value;"); + out(); + newLine(); + } + write("}"); + newLine(); + } + public final static String __local = "__local"; public final static String __global = "__global"; @@ -531,6 +605,18 @@ public void writePragma(String _name, boolean _enable) { newLine(); } + boolean usesLocalExchange = usesLocalMessagePassing(_entryPoint, "localExchange"); + boolean usesLocalBroadcast = usesLocalMessagePassing(_entryPoint, "localBroadcast"); + if (usesLocalExchange || usesLocalBroadcast) { + if (usesLocalExchange) { + writeLocalExchangeHelper(); + } + if (usesLocalBroadcast) { + writeLocalBroadcastHelper(); + } + newLine(); + } + if (Config.enableDoubles || _entryPoint.requiresDoublePragma()) { writePragma("cl_khr_fp64", true); newLine(); diff --git a/src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java b/src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java new file mode 100644 index 00000000..2888ca03 --- /dev/null +++ b/src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2016 - 2018 Syncleus, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.aparapi.codegen.test; + +import com.aparapi.Kernel; + +public class LocalMessagePassing extends Kernel { + + @Local int[] mailbox = new int[4]; + int[] values = new int[4]; + + @Override public void run() { + int localId = getLocalId(); + int value = values[getGlobalId()]; + int exchangeValue = localExchange(mailbox, (localId + 1) & 3, value); + int broadcastValue = localBroadcast(mailbox, 0, value); + values[getGlobalId()] = exchangeValue + broadcastValue; + } +} diff --git a/src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java b/src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java new file mode 100644 index 00000000..3566813a --- /dev/null +++ b/src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java @@ -0,0 +1,79 @@ +/** + * Copyright (c) 2016 - 2018 Syncleus, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.aparapi.codegen.test; + +import org.junit.Test; + +public class LocalMessagePassingTest extends com.aparapi.codegen.CodeGenJUnitBase { + private static final String[] expectedOpenCL = { + "int localExchange(__local int *_mailbox, int _targetLocalId, int _value){\n" + + " _mailbox[get_local_id(0)] = _value;\n" + + " barrier(CLK_LOCAL_MEM_FENCE);\n" + + " int value = _mailbox[_targetLocalId];\n" + + " barrier(CLK_LOCAL_MEM_FENCE);\n" + + " return value;\n" + + " }\n" + + " int localBroadcast(__local int *_mailbox, int _sourceLocalId, int _value){\n" + + " if (get_local_id(0)==_sourceLocalId){\n" + + " _mailbox[0] = _value;\n" + + " }\n" + + " barrier(CLK_LOCAL_MEM_FENCE);\n" + + " int value = _mailbox[0];\n" + + " barrier(CLK_LOCAL_MEM_FENCE);\n" + + " return value;\n" + + " }\n" + + "\n" + + " typedef struct This_s{\n" + + " __global int *values;\n" + + " __local int *mailbox;\n" + + " int passid;\n" + + " }This;\n" + + " int get_pass_id(This *this){\n" + + " return this->passid;\n" + + " }\n" + + " __kernel void run(\n" + + " __global int *values, \n" + + " __local int *mailbox, \n" + + " int passid\n" + + " ){\n" + + " This thisStruct;\n" + + " This* this=&thisStruct;\n" + + " this->values = values;\n" + + " this->mailbox = mailbox;\n" + + " this->passid = passid;\n" + + " {\n" + + " int localId = get_local_id(0);\n" + + " int value = this->values[get_global_id(0)];\n" + + " int exchangeValue = localExchange(this->mailbox, ((localId + 1) & 3), value);\n" + + " int broadcastValue = localBroadcast(this->mailbox, 0, value);\n" + + " this->values[get_global_id(0)] = exchangeValue + broadcastValue;\n" + + " return;\n" + + " }\n" + + " }\n" + + " " + }; + private static final Class expectedException = null; + + @Test + public void LocalMessagePassingTest() { + test(com.aparapi.codegen.test.LocalMessagePassing.class, expectedException, expectedOpenCL); + } + + @Test + public void LocalMessagePassingTestWorksWithCaching() { + test(com.aparapi.codegen.test.LocalMessagePassing.class, expectedException, expectedOpenCL); + } +} diff --git a/src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java b/src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java new file mode 100644 index 00000000..e91a8df3 --- /dev/null +++ b/src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java @@ -0,0 +1,52 @@ +/** + * Copyright (c) 2016 - 2018 Syncleus, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.aparapi.runtime; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +import com.aparapi.Kernel; +import com.aparapi.Range; +import com.aparapi.device.JavaDevice; + +public class LocalMessagePassingTest { + + @Test + public void localExchangeAndBroadcastWorkOnJavaThreadPool() { + final int[] values = {10, 20, 30, 40}; + + Kernel kernel = new Kernel() { + @Local int[] mailbox = new int[4]; + + @Override public void run() { + int localId = getLocalId(); + int value = values[getGlobalId()]; + int exchangeValue = localExchange(mailbox, (localId + 1) & 3, value); + int broadcastValue = localBroadcast(mailbox, 0, value); + values[getGlobalId()] = exchangeValue + broadcastValue; + } + }; + + try { + kernel.execute(Range.create(JavaDevice.THREAD_POOL, 4, 4)); + } finally { + kernel.dispose(); + } + + assertArrayEquals(new int[] {30, 40, 50, 20}, values); + } +}