// Java tutorial
/* * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * */ package org.nd4j.linalg.jcublas.buffer; import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import jcuda.Pointer; import jcuda.jcublas.JCublas2; import org.apache.commons.lang3.tuple.Triple; import org.nd4j.linalg.api.blas.BlasBufferUtil; import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.complex.IComplexDouble; import org.nd4j.linalg.api.complex.IComplexFloat; import org.nd4j.linalg.api.complex.IComplexNDArray; import org.nd4j.linalg.api.complex.IComplexNumber; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.allocation.HostDevicePointer; import org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy; import org.nd4j.linalg.jcublas.complex.CudaComplexConversion; import org.nd4j.linalg.jcublas.context.ContextHolder; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.util.PointerUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.lang.ref.WeakReference; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import 
java.util.concurrent.atomic.AtomicLong; /** * Base class for a data buffer * * @author Adam Gibson */ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer { static AtomicLong allocated = new AtomicLong(); static AtomicLong totalAllocated = new AtomicLong(); private static Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class); /** * Pointers to contexts covers this buffer on the gpu at offset 0 * for each thread. * * The column key is for offsets. If we only have buffer one device allocation per thread * we will clobber anything that is already allocated on the gpu. * * This also allows us to make a simplifying assumption about how to allocate the data as follows: * * Always allocate for offset zero by default. This allows us to reuse the same pointer with an offset * for each extra allocations (say for row wise operations) * * This also prevents duplicate uploads to the gpu. * Typical usage here: * DevicePointerInfo devicePointerInfo = pointersToContexts.get(Thread.currentThread().getName(),Triple.of(offset,length,stride)); */ protected transient Table<String, Triple<Integer, Integer, Integer>, DevicePointerInfo> pointersToContexts = HashBasedTable .create(); protected AtomicBoolean modified = new AtomicBoolean(false); protected Collection<String> referencing = Collections.synchronizedSet(new HashSet<String>()); protected transient WeakReference<DataBuffer> ref; protected AtomicBoolean freed = new AtomicBoolean(false); private Map<String, Boolean> copied = new ConcurrentHashMap<>(); public BaseCudaDataBuffer(ByteBuf buf, int length) { super(buf, length); // pointersToContexts = new SynchronizedTable<>(pointersToContexts); } public BaseCudaDataBuffer(float[] data, boolean copy) { super(data, copy); // pointersToContexts = new SynchronizedTable<>(pointersToContexts); } public BaseCudaDataBuffer(double[] data, boolean copy) { super(data, copy); // pointersToContexts = new SynchronizedTable<>(pointersToContexts); } public 
BaseCudaDataBuffer(int[] data, boolean copy) { super(data, copy); } /** * Base constructor * * @param length the length of the buffer * @param elementSize the size of each element */ public BaseCudaDataBuffer(int length, int elementSize) { super(length, elementSize); } public BaseCudaDataBuffer(int length) { super(length); } public BaseCudaDataBuffer(float[] data) { super(data); } public BaseCudaDataBuffer(int[] data) { super(data); } public BaseCudaDataBuffer(double[] data) { super(data); } public BaseCudaDataBuffer(byte[] data, int length) { super(data, length); } public BaseCudaDataBuffer(ByteBuffer buffer, int length) { super(buffer, length); } @Override protected void setNioBuffer() { wrappedBuffer = ByteBuffer.allocateDirect(elementSize * length); wrappedBuffer.order(ByteOrder.nativeOrder()); } @Override public void copyAtStride(DataBuffer buf, int n, int stride, int yStride, int offset, int yOffset) { super.copyAtStride(buf, n, stride, yStride, offset, yOffset); MemoryStrategy strategy = ContextHolder.getInstance().getMemoryStrategy(); strategy.setData(buf, offset, stride, length()); } @Override public boolean copied(String name) { Boolean copied = this.copied.get(name); if (copied == null) return false; return this.copied.get(name); } @Override public void setCopied(String name) { copied.put(name, true); } @Override public AllocationMode allocationMode() { return allocationMode; } @Override public ByteBuffer getHostBuffer() { return wrappedBuffer; } @Override public void setHostBuffer(ByteBuffer hostBuffer) { this.wrappedBuffer = hostBuffer; } @Override public Pointer getHostPointer() { throw new UnsupportedOperationException(); } @Override public Pointer getHostPointer(int offset) { throw new UnsupportedOperationException(); } @Override public void removeReferencing(String id) { referencing.remove(id); } @Override public Collection<String> references() { return referencing; } @Override public int getElementSize() { return elementSize; } @Override public 
void addReferencing(String id) { referencing.add(id); } @Override public void put(int i, IComplexNumber result) { modified.set(true); if (dataType() == DataBuffer.Type.FLOAT) { JCublas2.cublasSetVector((int) length(), getElementSize(), PointerUtil.getPointer(CudaComplexConversion.toComplex(result.asFloat())), 1, getHostPointer(), 1); } else { JCublas2.cublasSetVector((int) length(), getElementSize(), PointerUtil.getPointer(CudaComplexConversion.toComplexDouble(result.asDouble())), 1, getHostPointer(), 1); } } @Override public Pointer getDevicePointer(int stride, int offset, int length) { String name = Thread.currentThread().getName(); DevicePointerInfo devicePointerInfo = pointersToContexts.get(name, Triple.of(offset, length, 1)); if (devicePointerInfo == null) { int devicePointerLength = getElementSize() * length; allocated.addAndGet(devicePointerLength); totalAllocated.addAndGet(devicePointerLength); log.trace("Allocating {} bytes, total: {}, overall: {}", devicePointerLength, allocated.get(), totalAllocated); if (devicePointerInfo == null) { /** * Add zero first no matter what. * Allocate the whole buffer on the gpu * and use offsets for any other pointers that come in. * This will allow us to set device pointers with offsets * * with no extra allocation. * * Notice here we ignore the length of the actual array. * * We are going to allocate the whole buffer on the gpu only once. * */ if (!pointersToContexts.contains(name, Triple.of(0, this.length, 1))) { MemoryStrategy strat = ContextHolder.getInstance().getConf().getMemoryStrategy(); devicePointerInfo = (DevicePointerInfo) strat.alloc(this, 1, 0, this.length, true); pointersToContexts.put(name, Triple.of(0, this.length, 1), devicePointerInfo); } if (offset > 0 || length < length()) { /** * Store the length for the offset of the pointer. * Return the original pointer with an offset * (these pointers can't be reused?) * * With the device pointer info, * we want to store the original pointer. 
* When retrieving the vector from the gpu later, * we will use the recorded offset. * * Due to gpu instability (please correct me if I'm wrong here) * we can't seem to reuse the pointers with the offset specified, * therefore it is desirable to recreate this pointer later. * * This will prevent extra allocation as well * as inform the length for retrieving data from the gpu * for this particular offset and buffer. * */ HostDevicePointer zero = pointersToContexts.get(name, Triple.of(0, length, 1)).getPointers(); HostDevicePointer ret = new HostDevicePointer( zero.getHostPointer().withByteOffset(offset * getElementSize()), zero.getDevicePointer().withByteOffset(offset * getElementSize())); devicePointerInfo = new DevicePointerInfo(ret, length, stride, offset, false); pointersToContexts.put(name, Triple.of(offset, length, stride), devicePointerInfo); return ret.getDevicePointer(); } } freed.set(false); } /** * Return the device pointer with the specified offset. * Regardless of whether the device pointer has been allocated, * we need to return with it respect to the specified array * not the array's underlying buffer. 
*/ if (offset > 0) return devicePointerInfo.getPointers().getDevicePointer(); else return devicePointerInfo.getPointers().getDevicePointer(); } @Override public Pointer getDevicePointer(INDArray arr, int stride, int offset, int length) { String name = Thread.currentThread().getName(); DevicePointerInfo devicePointerInfo = pointersToContexts.get(name, Triple.of(offset, length, stride)); if (devicePointerInfo == null) { int devicePointerLength = getElementSize() * length; allocated.addAndGet(devicePointerLength); totalAllocated.addAndGet(devicePointerLength); log.trace("Allocating {} bytes, total: {}, overall: {}", devicePointerLength, allocated.get(), totalAllocated); //check its the same object if (arr.data() != this) { throw new IllegalArgumentException( "Unable to get pointer for array that doesn't have this as the buffer"); } int compareLength = arr instanceof IComplexNDArray ? arr.length() * 2 : arr.length(); /** * Add zero first no matter what. * Allocate the whole buffer on the gpu * and use offsets for any other pointers that come in. * This will allow us to set device pointers with offsets * * with no extra allocation. * * Notice here we ignore the length of the actual array. * * We are going to allocate the whole buffer on the gpu only once. * */ if (!pointersToContexts.contains(name, Triple.of(0, this.length, 1))) { devicePointerInfo = (DevicePointerInfo) ContextHolder.getInstance().getConf().getMemoryStrategy() .alloc(this, 1, 0, this.length, true); pointersToContexts.put(name, Triple.of(0, this.length, 1), devicePointerInfo); } if (offset > 0) { /** * Store the length for the offset of the pointer. * Return the original pointer with an offset * (these pointers can't be reused?) * * With the device pointer info, * we want to store the original pointer. * When retrieving the vector from the gpu later, * we will use the recorded offset. 
* * Due to gpu instability (please correct me if I'm wrong here) * we can't seem to reuse the pointers with the offset specified, * therefore it is desirable to recreate this pointer later. * * This will prevent extra allocation as well * as inform the length for retrieving data from the gpu * for this particular offset and buffer. * */ DevicePointerInfo info2 = pointersToContexts.get(name, Triple.of(0, this.length, 1)); if (info2 == null) throw new IllegalStateException( "No pointer found for name " + name + " and offset/length " + offset + " / " + length); HostDevicePointer zero = info2.getPointers(); HostDevicePointer retOffset = new HostDevicePointer( zero.getHostPointer().withByteOffset(offset * getElementSize()), zero.getDevicePointer().withByteOffset(offset * getElementSize())); Pointer ret = retOffset.getDevicePointer(); devicePointerInfo = new DevicePointerInfo(retOffset, length, stride, offset, false); pointersToContexts.put(name, Triple.of(offset, compareLength, stride), devicePointerInfo); return ret; } else if (offset == 0 && compareLength < arr.data().length()) { DevicePointerInfo info2 = pointersToContexts.get(name, Triple.of(0, this.length, 1)); if (info2 == null) { throw new IllegalStateException( "No pointer found for name " + name + " and offset/length " + offset + " / " + length); } DevicePointerInfo info3 = new DevicePointerInfo(info2.getPointers(), this.length, stride, arr.offset(), false); int compareLength2 = arr instanceof IComplexNDArray ? arr.length() * 2 : arr.length(); /** * Need a pointer that * points at the buffer but doesnt extend all the way to the end. * This is for data like the first row of a matrix * that has zero offset but does not extend all the way to the end of the buffer. */ pointersToContexts.put(name, Triple.of(offset, compareLength2, stride), info3); return info3.getPointers().getDevicePointer(); } freed.set(false); } /** * Return the device pointer with the specified offset. 
* Regardless of whether the device pointer has been allocated, * we need to return with it respect to the specified array * not the array's underlying buffer. */ if (devicePointerInfo == null && offset == 0 && length < length()) { DevicePointerInfo origin = pointersToContexts.get(Thread.currentThread().getName(), Triple.of(0, length(), 1)); DevicePointerInfo newInfo = new DevicePointerInfo(origin.getPointers(), length, stride, 0, false); return newInfo.getPointers().getDevicePointer(); } return devicePointerInfo.getPointers().getDevicePointer().withByteOffset(offset * getElementSize()); } @Override public void set(Pointer pointer) { modified.set(true); if (dataType() == DataBuffer.Type.DOUBLE) { JCublas2.cublasDcopy(ContextHolder.getInstance().getHandle(), length(), pointer, 1, getHostPointer(), 1); } else { JCublas2.cublasScopy(ContextHolder.getInstance().getHandle(), length(), pointer, 1, getHostPointer(), 1); } } private void copyOneElement(int i, double val) { if (pointersToContexts != null) for (DevicePointerInfo info : pointersToContexts.values()) { if (dataType() == Type.FLOAT) JCublas2.cublasSetVector(1, getElementSize(), Pointer.to(new float[] { (float) val }), 1, info.getPointers().getDevicePointer().withByteOffset(getElementSize() * i), 1); else JCublas2.cublasSetVector(1, getElementSize(), Pointer.to(new double[] { val }), 1, info.getPointers().getDevicePointer().withByteOffset(getElementSize() * i), 1); } } @Override public void put(int i, float element) { super.put(i, element); copyOneElement(i, element); } @Override public void put(int i, double element) { super.put(i, element); copyOneElement(i, element); } @Override public IComplexFloat getComplexFloat(int i) { return Nd4j.createFloat(getFloat(i), getFloat(i + 1)); } @Override public IComplexDouble getComplexDouble(int i) { return Nd4j.createDouble(getDouble(i), getDouble(i + 1)); } @Override public IComplexNumber getComplex(int i) { return dataType() == DataBuffer.Type.FLOAT ? 
getComplexFloat(i) : getComplexDouble(i); } /** * Set an individual element * * @param index the index of the element * @param from the element to get data from */ protected void set(int index, int length, Pointer from, int inc) { modified.set(true); int offset = getElementSize() * index; if (offset >= length() * getElementSize()) throw new IllegalArgumentException( "Illegal offset " + offset + " with index of " + index + " and length " + length()); JCublas2.cublasSetVectorAsync(length, getElementSize(), from, inc, getHostPointer().withByteOffset(offset), 1, ContextHolder.getInstance().getCudaStream()); ContextHolder.getInstance().setContext(); } /** * Set an individual element * * @param index the index of the element * @param from the element to get data from */ protected void set(int index, int length, Pointer from) { set(index, length, from, 1); } @Override public void assign(DataBuffer data) { JCudaBuffer buf = (JCudaBuffer) data; set(0, buf.getHostPointer()); } /** * Set an individual element * * @param index the index of the element * @param from the element to get data from */ protected void set(int index, Pointer from) { set(index, 1, from); } @Override public boolean freeDevicePointer(int offset, int length) { String name = Thread.currentThread().getName(); DevicePointerInfo devicePointerInfo = pointersToContexts.get(name, offset); //nothing to free, there was no copy. Only the gpu pointer was reused with a different offset. 
if (offset != 0) pointersToContexts.remove(name, offset); else if (offset == 0 && isPersist) { return true; } else if (devicePointerInfo != null && !freed.get()) { allocated.addAndGet(-devicePointerInfo.getLength()); log.trace("freeing {} bytes, total: {}", devicePointerInfo.getLength(), allocated.get()); ContextHolder.getInstance().getMemoryStrategy().free(this, offset, length); freed.set(true); copied.remove(name); pointersToContexts.remove(name, Triple.of(offset, length, devicePointerInfo.getStride())); return true; } return false; } @Override public synchronized void copyToHost(CudaContext context, int offset, int length, int stride) { DevicePointerInfo devicePointerInfo = pointersToContexts.get(Thread.currentThread().getName(), Triple.of(offset, length, stride)); if (devicePointerInfo == null) throw new IllegalStateException("No pointer found for offset " + offset); //prevent inconsistent pointers if (devicePointerInfo.getOffset() != offset) throw new IllegalStateException("Device pointer offset didn't match specified offset in pointer map"); if (devicePointerInfo != null) { ContextHolder.getInstance().getMemoryStrategy().copyToHost(this, offset, stride, length, null, offset, stride); } else throw new IllegalStateException("No offset found to copy"); //synchronize for the copy to avoid data inconsistencies context.syncOldStream(); } @Override public synchronized void copyToHost(int offset, int length) { DevicePointerInfo devicePointerInfo = pointersToContexts.get(Thread.currentThread().getName(), Triple.of(offset, length, 1)); if (devicePointerInfo == null) throw new IllegalStateException("No pointer found for offset " + offset); //prevent inconsistent pointers if (devicePointerInfo.getOffset() != offset) throw new IllegalStateException("Device pointer offset didn't match specified offset in pointer map"); if (devicePointerInfo != null) { int deviceStride = devicePointerInfo.getStride(); int deviceOffset = devicePointerInfo.getOffset(); int deviceLength = 
(int) devicePointerInfo.getLength(); if (deviceOffset == 0 && length < length()) { /** * The way the data works out the stride for retrieving the data * should be 1. * * The device stride should be used for resetting the data. * * This is for the edge case where the offset is zero and * the length of the pointer is < the actual buffer length itself. * DevicePointerInfo devicePointerInfo = pointersToContexts.get(Thread.currentThread().getName(),Triple.of(offset,length,1)); */ ContextHolder.getInstance().getMemoryStrategy().copyToHost(this, offset, deviceStride, deviceLength, null, deviceOffset, deviceStride); } else { ContextHolder.getInstance().getMemoryStrategy().copyToHost(this, offset, deviceStride, deviceLength, null, deviceOffset, deviceStride); } } } @Override public void flush() { throw new UnsupportedOperationException(); } @Override public void destroy() { } private void writeObject(java.io.ObjectOutputStream stream) throws IOException { stream.defaultWriteObject(); write(stream); } private void readObject(java.io.ObjectInputStream stream) throws IOException, ClassNotFoundException { doReadObject(stream); copied = new HashMap<>(); pointersToContexts = HashBasedTable.create(); ref = new WeakReference<DataBuffer>(this, Nd4j.bufferRefQueue()); freed = new AtomicBoolean(false); } @Override public Table<String, Triple<Integer, Integer, Integer>, DevicePointerInfo> getPointersToContexts() { return pointersToContexts; } public void setPointersToContexts( Table<String, Triple<Integer, Integer, Integer>, DevicePointerInfo> pointersToContexts) { this.pointersToContexts = pointersToContexts; } @Override public String toString() { StringBuffer sb = new StringBuffer(); sb.append("["); for (int i = 0; i < length(); i++) { sb.append(getDouble(i)); if (i < length() - 1) sb.append(","); } sb.append("]"); return sb.toString(); } }