Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/main/java/com/aparapi/Kernel.java
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,56 @@ protected final int atomicXor(AtomicInteger p, int val) {
return p.getAndAccumulate(val, xorOperator);
}

/**
* Exchanges an integer value with another work-item in the same work-group using a caller supplied local-memory
* buffer.
* <br>
* Every work-item in the work-group must call this method, and <code>_mailbox</code> must be a {@link Local}
* buffer, or a buffer named with {@link #LOCAL_SUFFIX}, with at least <code>getLocalSize()</code> entries. Each
* work-item publishes <code>_value</code> at its local id, waits for the work-group, reads the value published by
* <code>_targetLocalId</code>, and waits again before returning. As with {@link #localBarrier()}, all work-items
* must reach the same calls in the same order.
*
* @param _mailbox local-memory buffer used as a mailbox between work-items
* @param _targetLocalId local id whose value should be read by the current work-item
* @param _value value published by the current work-item
* @return value published by <code>_targetLocalId</code>
*/
@OpenCLMapping
@Experimental
protected final int localExchange(int[] _mailbox, int _targetLocalId, int _value) {
_mailbox[getLocalId()] = _value;
localBarrier();
final int value = _mailbox[_targetLocalId];
localBarrier();
return value;
}

/**
* Broadcasts an integer value from one work-item to the rest of the current work-group.
* <br>
* Every work-item in the work-group must call this method, and <code>_mailbox</code> must be a {@link Local}
* buffer, or a buffer named with {@link #LOCAL_SUFFIX}, with at least one entry. The work-item whose local id equals
* <code>_sourceLocalId</code> writes <code>_value</code>, then all work-items read and return that value. As with
* {@link #localBarrier()}, all work-items must reach the same calls in the same order.
*
* @param _mailbox local-memory buffer used to publish the broadcast value
* @param _sourceLocalId local id of the work-item providing the broadcast value
* @param _value value published by the source work-item
* @return value published by <code>_sourceLocalId</code>
*/
@OpenCLMapping
@Experimental
protected final int localBroadcast(int[] _mailbox, int _sourceLocalId, int _value) {
if (getLocalId() == _sourceLocalId) {
_mailbox[0] = _value;
}
localBarrier();
final int value = _mailbox[0];
localBarrier();
return value;
}

/**
* Wait for all kernels in the current work group to rendezvous at this call before continuing execution.<br>
* It will also enforce memory ordering, such that modifications made by each thread in the work-group, to the memory,
Expand Down
86 changes: 86 additions & 0 deletions src/main/java/com/aparapi/internal/writer/KernelWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,80 @@ public void writePragma(String _name, boolean _enable) {
newLine();
}

private boolean usesLocalMessagePassing(Entrypoint _entryPoint, String _methodName) {
if (usesLocalMessagePassing(_entryPoint.getMethodModel(), _methodName)) {
return true;
}
for (final MethodModel methodModel : _entryPoint.getCalledMethods()) {
if (usesLocalMessagePassing(methodModel, _methodName)) {
return true;
}
}
return false;
}

private boolean usesLocalMessagePassing(MethodModel _methodModel, String _methodName) {
for (final MethodCall methodCall : _methodModel.getMethodCalls()) {
final MethodEntry methodEntry = methodCall.getConstantPoolMethodEntry();
final String methodName = methodEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8();
final String methodSignature = methodEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8();
if (Kernel.isMappedMethod(methodEntry) && _methodName.equals(methodName) && "([III)I".equals(methodSignature)) {
return true;
}
}
return false;
}

private void writeLocalExchangeHelper() {
write("int localExchange(__local int *_mailbox, int _targetLocalId, int _value){");
in();
{
newLine();
write("_mailbox[get_local_id(0)] = _value;");
newLine();
write("barrier(CLK_LOCAL_MEM_FENCE);");
newLine();
write("int value = _mailbox[_targetLocalId];");
newLine();
write("barrier(CLK_LOCAL_MEM_FENCE);");
newLine();
write("return value;");
out();
newLine();
}
write("}");
newLine();
}

private void writeLocalBroadcastHelper() {
write("int localBroadcast(__local int *_mailbox, int _sourceLocalId, int _value){");
in();
{
newLine();
write("if (get_local_id(0)==_sourceLocalId){");
in();
{
newLine();
write("_mailbox[0] = _value;");
out();
newLine();
}
write("}");
newLine();
write("barrier(CLK_LOCAL_MEM_FENCE);");
newLine();
write("int value = _mailbox[0];");
newLine();
write("barrier(CLK_LOCAL_MEM_FENCE);");
newLine();
write("return value;");
out();
newLine();
}
write("}");
newLine();
}

public final static String __local = "__local";

public final static String __global = "__global";
Expand Down Expand Up @@ -531,6 +605,18 @@ public void writePragma(String _name, boolean _enable) {
newLine();
}

boolean usesLocalExchange = usesLocalMessagePassing(_entryPoint, "localExchange");
boolean usesLocalBroadcast = usesLocalMessagePassing(_entryPoint, "localBroadcast");
if (usesLocalExchange || usesLocalBroadcast) {
if (usesLocalExchange) {
writeLocalExchangeHelper();
}
if (usesLocalBroadcast) {
writeLocalBroadcastHelper();
}
newLine();
}

if (Config.enableDoubles || _entryPoint.requiresDoublePragma()) {
writePragma("cl_khr_fp64", true);
newLine();
Expand Down
32 changes: 32 additions & 0 deletions src/test/java/com/aparapi/codegen/test/LocalMessagePassing.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
* Copyright (c) 2016 - 2018 Syncleus, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.aparapi.codegen.test;

import com.aparapi.Kernel;

public class LocalMessagePassing extends Kernel {

@Local int[] mailbox = new int[4];
int[] values = new int[4];

@Override public void run() {
int localId = getLocalId();
int value = values[getGlobalId()];
int exchangeValue = localExchange(mailbox, (localId + 1) & 3, value);
int broadcastValue = localBroadcast(mailbox, 0, value);
values[getGlobalId()] = exchangeValue + broadcastValue;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* Copyright (c) 2016 - 2018 Syncleus, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.aparapi.codegen.test;

import org.junit.Test;

public class LocalMessagePassingTest extends com.aparapi.codegen.CodeGenJUnitBase {
private static final String[] expectedOpenCL = {
"int localExchange(__local int *_mailbox, int _targetLocalId, int _value){\n" +
" _mailbox[get_local_id(0)] = _value;\n" +
" barrier(CLK_LOCAL_MEM_FENCE);\n" +
" int value = _mailbox[_targetLocalId];\n" +
" barrier(CLK_LOCAL_MEM_FENCE);\n" +
" return value;\n" +
" }\n" +
" int localBroadcast(__local int *_mailbox, int _sourceLocalId, int _value){\n" +
" if (get_local_id(0)==_sourceLocalId){\n" +
" _mailbox[0] = _value;\n" +
" }\n" +
" barrier(CLK_LOCAL_MEM_FENCE);\n" +
" int value = _mailbox[0];\n" +
" barrier(CLK_LOCAL_MEM_FENCE);\n" +
" return value;\n" +
" }\n" +
"\n" +
" typedef struct This_s{\n" +
" __global int *values;\n" +
" __local int *mailbox;\n" +
" int passid;\n" +
" }This;\n" +
" int get_pass_id(This *this){\n" +
" return this->passid;\n" +
" }\n" +
" __kernel void run(\n" +
" __global int *values, \n" +
" __local int *mailbox, \n" +
" int passid\n" +
" ){\n" +
" This thisStruct;\n" +
" This* this=&thisStruct;\n" +
" this->values = values;\n" +
" this->mailbox = mailbox;\n" +
" this->passid = passid;\n" +
" {\n" +
" int localId = get_local_id(0);\n" +
" int value = this->values[get_global_id(0)];\n" +
" int exchangeValue = localExchange(this->mailbox, ((localId + 1) & 3), value);\n" +
" int broadcastValue = localBroadcast(this->mailbox, 0, value);\n" +
" this->values[get_global_id(0)] = exchangeValue + broadcastValue;\n" +
" return;\n" +
" }\n" +
" }\n" +
" "
};
private static final Class<? extends com.aparapi.internal.exception.AparapiException> expectedException = null;

@Test
public void LocalMessagePassingTest() {
test(com.aparapi.codegen.test.LocalMessagePassing.class, expectedException, expectedOpenCL);
}

@Test
public void LocalMessagePassingTestWorksWithCaching() {
test(com.aparapi.codegen.test.LocalMessagePassing.class, expectedException, expectedOpenCL);
}
}
52 changes: 52 additions & 0 deletions src/test/java/com/aparapi/runtime/LocalMessagePassingTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* Copyright (c) 2016 - 2018 Syncleus, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.aparapi.runtime;

import static org.junit.Assert.assertArrayEquals;

import org.junit.Test;

import com.aparapi.Kernel;
import com.aparapi.Range;
import com.aparapi.device.JavaDevice;

public class LocalMessagePassingTest {

@Test
public void localExchangeAndBroadcastWorkOnJavaThreadPool() {
final int[] values = {10, 20, 30, 40};

Kernel kernel = new Kernel() {
@Local int[] mailbox = new int[4];

@Override public void run() {
int localId = getLocalId();
int value = values[getGlobalId()];
int exchangeValue = localExchange(mailbox, (localId + 1) & 3, value);
int broadcastValue = localBroadcast(mailbox, 0, value);
values[getGlobalId()] = exchangeValue + broadcastValue;
}
};

try {
kernel.execute(Range.create(JavaDevice.THREAD_POOL, 4, 4));
} finally {
kernel.dispose();
}

assertArrayEquals(new int[] {30, 40, 50, 20}, values);
}
}