diff --git a/src/main/java/com/aparapi/Kernel.java b/src/main/java/com/aparapi/Kernel.java
index 4b9686db..a2f7f62f 100644
--- a/src/main/java/com/aparapi/Kernel.java
+++ b/src/main/java/com/aparapi/Kernel.java
@@ -2452,6 +2452,56 @@ protected final int atomicXor(AtomicInteger p, int val) {
return p.getAndAccumulate(val, xorOperator);
}
+ /**
+ * Exchanges an integer value with another work-item in the same work-group using a caller supplied local-memory
+ * buffer.
+ *
+ * Every work-item in the work-group must call this method, and _mailbox must be a {@link Local}
+ * buffer, or a buffer named with {@link #LOCAL_SUFFIX}, with at least getLocalSize() entries. Each
+ * work-item publishes _value at its local id, waits for the work-group, reads the value published by
+ * _targetLocalId, and waits again before returning. As with {@link #localBarrier()}, all work-items
+ * must reach the same calls in the same order.
+ *
+ * @param _mailbox local-memory buffer used as a mailbox between work-items
+ * @param _targetLocalId local id whose value should be read by the current work-item
+ * @param _value value published by the current work-item
+ * @return value published by _targetLocalId
+ */
+ @OpenCLMapping
+ @Experimental
+ protected final int localExchange(int[] _mailbox, int _targetLocalId, int _value) {
+ _mailbox[getLocalId()] = _value;
+ localBarrier();
+ final int value = _mailbox[_targetLocalId];
+ localBarrier();
+ return value;
+ }
+
+ /**
+ * Broadcasts an integer value from one work-item to the rest of the current work-group.
+ *
+ * Every work-item in the work-group must call this method, and _mailbox must be a {@link Local}
+ * buffer, or a buffer named with {@link #LOCAL_SUFFIX}, with at least one entry. The work-item whose local id equals
+ * _sourceLocalId writes _value, then all work-items read and return that value. As with
+ * {@link #localBarrier()}, all work-items must reach the same calls in the same order.
+ *
+ * @param _mailbox local-memory buffer used to publish the broadcast value
+ * @param _sourceLocalId local id of the work-item providing the broadcast value
+ * @param _value value published by the source work-item
+ * @return value published by _sourceLocalId
+ */
+ @OpenCLMapping
+ @Experimental
+ protected final int localBroadcast(int[] _mailbox, int _sourceLocalId, int _value) {
+ if (getLocalId() == _sourceLocalId) {
+ _mailbox[0] = _value;
+ }
+ localBarrier();
+ final int value = _mailbox[0];
+ localBarrier();
+ return value;
+ }
+
/**
* Wait for all kernels in the current work group to rendezvous at this call before continuing execution.
* It will also enforce memory ordering, such that modifications made by each thread in the work-group, to the memory,
diff --git a/src/main/java/com/aparapi/internal/writer/KernelWriter.java b/src/main/java/com/aparapi/internal/writer/KernelWriter.java
index 0a72a33f..6073f837 100644
--- a/src/main/java/com/aparapi/internal/writer/KernelWriter.java
+++ b/src/main/java/com/aparapi/internal/writer/KernelWriter.java
@@ -318,6 +318,80 @@ public void writePragma(String _name, boolean _enable) {
newLine();
}
+ private boolean usesLocalMessagePassing(Entrypoint _entryPoint, String _methodName) {
+ if (usesLocalMessagePassing(_entryPoint.getMethodModel(), _methodName)) {
+ return true;
+ }
+ for (final MethodModel methodModel : _entryPoint.getCalledMethods()) {
+ if (usesLocalMessagePassing(methodModel, _methodName)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private boolean usesLocalMessagePassing(MethodModel _methodModel, String _methodName) {
+ for (final MethodCall methodCall : _methodModel.getMethodCalls()) {
+ final MethodEntry methodEntry = methodCall.getConstantPoolMethodEntry();
+ final String methodName = methodEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8();
+ final String methodSignature = methodEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8();
+ if (Kernel.isMappedMethod(methodEntry) && _methodName.equals(methodName) && "([III)I".equals(methodSignature)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private void writeLocalExchangeHelper() {
+ write("int localExchange(__local int *_mailbox, int _targetLocalId, int _value){");
+ in();
+ {
+ newLine();
+ write("_mailbox[get_local_id(0)] = _value;");
+ newLine();
+ write("barrier(CLK_LOCAL_MEM_FENCE);");
+ newLine();
+ write("int value = _mailbox[_targetLocalId];");
+ newLine();
+ write("barrier(CLK_LOCAL_MEM_FENCE);");
+ newLine();
+ write("return value;");
+ out();
+ newLine();
+ }
+ write("}");
+ newLine();
+ }
+
+ private void writeLocalBroadcastHelper() {
+ write("int localBroadcast(__local int *_mailbox, int _sourceLocalId, int _value){");
+ in();
+ {
+ newLine();
+ write("if (get_local_id(0)==_sourceLocalId){");
+ in();
+ {
+ newLine();
+ write("_mailbox[0] = _value;");
+ out();
+ newLine();
+ }
+ write("}");
+ newLine();
+ write("barrier(CLK_LOCAL_MEM_FENCE);");
+ newLine();
+ write("int value = _mailbox[0];");
+ newLine();
+ write("barrier(CLK_LOCAL_MEM_FENCE);");
+ newLine();
+ write("return value;");
+ out();
+ newLine();
+ }
+ write("}");
+ newLine();
+ }
+
public final static String __local = "__local";
public final static String __global = "__global";
@@ -531,6 +605,18 @@ public void writePragma(String _name, boolean _enable) {
newLine();
}
+ boolean usesLocalExchange = usesLocalMessagePassing(_entryPoint, "localExchange");
+ boolean usesLocalBroadcast = usesLocalMessagePassing(_entryPoint, "localBroadcast");
+ if (usesLocalExchange || usesLocalBroadcast) {
+ if (usesLocalExchange) {
+ writeLocalExchangeHelper();
+ }
+ if (usesLocalBroadcast) {
+ writeLocalBroadcastHelper();
+ }
+ newLine();
+ }
+
if (Config.enableDoubles || _entryPoint.requiresDoublePragma()) {
writePragma("cl_khr_fp64", true);
newLine();
diff --git a/src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java b/src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java
new file mode 100644
index 00000000..2888ca03
--- /dev/null
+++ b/src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java
@@ -0,0 +1,32 @@
+/**
+ * Copyright (c) 2016 - 2018 Syncleus, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.aparapi.codegen.test;
+
+import com.aparapi.Kernel;
+
+public class LocalMessagePassing extends Kernel {
+
+ @Local int[] mailbox = new int[4];
+ int[] values = new int[4];
+
+ @Override public void run() {
+ int localId = getLocalId();
+ int value = values[getGlobalId()];
+ int exchangeValue = localExchange(mailbox, (localId + 1) & 3, value);
+ int broadcastValue = localBroadcast(mailbox, 0, value);
+ values[getGlobalId()] = exchangeValue + broadcastValue;
+ }
+}
diff --git a/src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java b/src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java
new file mode 100644
index 00000000..3566813a
--- /dev/null
+++ b/src/test/java/com/aparapi/codegen/test/LocalMessagePassingTest.java
@@ -0,0 +1,79 @@
+/**
+ * Copyright (c) 2016 - 2018 Syncleus, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.aparapi.codegen.test;
+
+import org.junit.Test;
+
+public class LocalMessagePassingTest extends com.aparapi.codegen.CodeGenJUnitBase {
+ private static final String[] expectedOpenCL = {
+ "int localExchange(__local int *_mailbox, int _targetLocalId, int _value){\n" +
+ " _mailbox[get_local_id(0)] = _value;\n" +
+ " barrier(CLK_LOCAL_MEM_FENCE);\n" +
+ " int value = _mailbox[_targetLocalId];\n" +
+ " barrier(CLK_LOCAL_MEM_FENCE);\n" +
+ " return value;\n" +
+ " }\n" +
+ " int localBroadcast(__local int *_mailbox, int _sourceLocalId, int _value){\n" +
+ " if (get_local_id(0)==_sourceLocalId){\n" +
+ " _mailbox[0] = _value;\n" +
+ " }\n" +
+ " barrier(CLK_LOCAL_MEM_FENCE);\n" +
+ " int value = _mailbox[0];\n" +
+ " barrier(CLK_LOCAL_MEM_FENCE);\n" +
+ " return value;\n" +
+ " }\n" +
+ "\n" +
+ " typedef struct This_s{\n" +
+ " __global int *values;\n" +
+ " __local int *mailbox;\n" +
+ " int passid;\n" +
+ " }This;\n" +
+ " int get_pass_id(This *this){\n" +
+ " return this->passid;\n" +
+ " }\n" +
+ " __kernel void run(\n" +
+ " __global int *values, \n" +
+ " __local int *mailbox, \n" +
+ " int passid\n" +
+ " ){\n" +
+ " This thisStruct;\n" +
+ " This* this=&thisStruct;\n" +
+ " this->values = values;\n" +
+ " this->mailbox = mailbox;\n" +
+ " this->passid = passid;\n" +
+ " {\n" +
+ " int localId = get_local_id(0);\n" +
+ " int value = this->values[get_global_id(0)];\n" +
+ " int exchangeValue = localExchange(this->mailbox, ((localId + 1) & 3), value);\n" +
+ " int broadcastValue = localBroadcast(this->mailbox, 0, value);\n" +
+ " this->values[get_global_id(0)] = exchangeValue + broadcastValue;\n" +
+ " return;\n" +
+ " }\n" +
+ " }\n" +
+ " "
+ };
+ private static final Class extends com.aparapi.internal.exception.AparapiException> expectedException = null;
+
+ @Test
+ public void LocalMessagePassingTest() {
+ test(com.aparapi.codegen.test.LocalMessagePassing.class, expectedException, expectedOpenCL);
+ }
+
+ @Test
+ public void LocalMessagePassingTestWorksWithCaching() {
+ test(com.aparapi.codegen.test.LocalMessagePassing.class, expectedException, expectedOpenCL);
+ }
+}
diff --git a/src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java b/src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java
new file mode 100644
index 00000000..e91a8df3
--- /dev/null
+++ b/src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java
@@ -0,0 +1,52 @@
+/**
+ * Copyright (c) 2016 - 2018 Syncleus, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.aparapi.runtime;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import org.junit.Test;
+
+import com.aparapi.Kernel;
+import com.aparapi.Range;
+import com.aparapi.device.JavaDevice;
+
+public class LocalMessagePassingTest {
+
+ @Test
+ public void localExchangeAndBroadcastWorkOnJavaThreadPool() {
+ final int[] values = {10, 20, 30, 40};
+
+ Kernel kernel = new Kernel() {
+ @Local int[] mailbox = new int[4];
+
+ @Override public void run() {
+ int localId = getLocalId();
+ int value = values[getGlobalId()];
+ int exchangeValue = localExchange(mailbox, (localId + 1) & 3, value);
+ int broadcastValue = localBroadcast(mailbox, 0, value);
+ values[getGlobalId()] = exchangeValue + broadcastValue;
+ }
+ };
+
+ try {
+ kernel.execute(Range.create(JavaDevice.THREAD_POOL, 4, 4));
+ } finally {
+ kernel.dispose();
+ }
+
+ assertArrayEquals(new int[] {30, 40, 50, 20}, values);
+ }
+}