Java
Introduction
Band provides a Java API to support Android native applications. This document provides a quick overview of how to use the Java API.
Example
The following example shows how to use the Java API to create an engine, register a model, create input and output tensors, and run the model. Link provides a complete example of how to use the Java API to run a model on an image.
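import android.media.Image;
import android.os.Bundle;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.List;
import java.util.Arrays;
import java.util.ArrayList;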
import org.mrsnu.band.BackendType;
import org.mrsnu.band.Band;
import org.mrsnu.band.Buffer;
import org.mrsnu.band.BufferFormat;
import org.mrsnu.band.Config;
import org.mrsnu.band.ConfigBuilder;
import org.mrsnu.band.CpuMaskFlag;
import org.mrsnu.band.Device;
import org.mrsnu.band.Engine;
import org.mrsnu.band.ImageProcessor;
import org.mrsnu.band.ImageProcessorBuilder;
import org.mrsnu.band.Model;
import org.mrsnu.band.Request;
import org.mrsnu.band.SchedulerType;
import org.mrsnu.band.SubgraphPreparationType;
import org.mrsnu.band.Tensor;
Engine engine;
ImageProcessor processor;
Model classifier;
List<Tensor> inputs = new ArrayList<>();
List<Tensor> outputs = new ArrayList<>();
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
// 1. Load the Band library
Band.init();
// 2. Create a configuration for the engine.
ConfigBuilder b = new ConfigBuilder();
b.addPlannerLogPath("/data/local/tmp/log.json");
b.addSchedulers(
new SchedulerType[]{SchedulerType.HETEROGENEOUS_EARLIEST_FINISH_TIME});
b.addMinimumSubgraphSize(7);
b.addSubgraphPreparationType(SubgraphPreparationType.MERGE_UNIT_SUBGRAPH);
b.addCPUMask(CpuMaskFlag.ALL);
b.addPlannerCPUMask(CpuMaskFlag.PRIMARY);
b.addWorkers(new Device[]{Device.CPU, Device.GPU, Device.DSP, Device.NPU});
b.addWorkerNumThreads(new int[]{1, 1, 1, 1});
b.addWorkerCPUMasks(new CpuMaskFlag[]{CpuMaskFlag.ALL, CpuMaskFlag.ALL,
CpuMaskFlag.ALL, CpuMaskFlag.ALL});
b.addSmoothingFactor(0.1f);
b.addProfileDataPath("/data/local/tmp/profile.json");
b.addOnline(true);
b.addNumWarmups(1);
b.addNumRuns(1);
b.addAllowWorkSteal(true);
b.addAvailabilityCheckIntervalMs(30000);
b.addScheduleWindowSize(10);
// 3. Create an engine with the configuration.
Config config = b.build();
engine = new Engine(config);
// 4. Create a model from a file and register it to the engine.
classifier = new Model(BackendType.TFLITE, "/data/local/tmp/mobilenet_v2_1.0_224_quant.tflite");
engine.registerModel(classifier);
// 5. Create input and output tensors.
inputs.add(engine.createInputTensor(classifier, 0));
outputs.add(engine.createOutputTensor(classifier, 0));
// 6. Create an image processor to preprocess the input image.
ImageProcessorBuilder processorBuilder = new ImageProcessorBuilder();
processorBuilder.addColorSpaceConvert(BufferFormat.RGB);
processorBuilder.addResize(224, 224);
processorBuilder.addDataTypeConvert();
processor = processorBuilder.build();
}
// called when an image is available (e.g. from camera2 API)
private void processImage(Image image) {
// 7. Wrap the image in a Buffer and preprocess it into the input tensor.
final Buffer buffer = new Buffer(image.getPlanes(), image.getWidth(), image.getHeight(), BufferFormat.YV12);
processor.process(buffer, inputs.get(0));
// 8. Run the model.
engine.requestSync(classifier, inputs, outputs);
// 9. Postprocess the output.
ByteBuffer rawResults =
outputs.get(0).getData().order(ByteOrder.nativeOrder());
rawResults.rewind();
FloatBuffer results = rawResults.asFloatBuffer();
int class_index = 0;
float max = Float.NEGATIVE_INFINITY;
for (int i = 0; i < results.remaining(); i++) {
float value = results.get(i);
if (value > max) {
max = value;
class_index = i;
}
}
// 10. Use the result (e.g., class_index should be 282 (tiger cat) for cat image)
}
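Step 9 above can be factored into a small helper that works on any ByteBuffer of native-order floats. The following is a minimal sketch using only java.nio; the class name ArgMaxSketch and the method argMax are hypothetical and not part of the Band API.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

final class ArgMaxSketch {
    // Returns the index of the largest float stored in `raw`.
    // Returns -1 if no value compares greater than negative infinity
    // (for example, when the buffer is empty).
    static int argMax(ByteBuffer raw) {
        // Work on a duplicate so the caller's position and limit stay untouched.
        ByteBuffer view = raw.duplicate();
        // duplicate() resets the byte order, so set it explicitly.
        view.order(ByteOrder.nativeOrder());
        view.rewind();
        FloatBuffer scores = view.asFloatBuffer();

        int best = -1;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.limit(); i++) {
            float score = scores.get(i); // absolute read, position is unchanged
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }
}
```

With such a helper, the loop in processImage would reduce to a single call on outputs.get(0).getData(), assuming the output tensor holds float scores.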
Java API classes
The Java API is composed of the following core classes:
- Config
- ConfigBuilder
- Engine
- Model
- Tensor
- Quantization
- Request
Optional classes for preprocessing:
- Buffer
- ImageProcessor
- ImageProcessorBuilder
Enums:
- BackendType
- BufferFormat
- CpuMaskFlag
- DataType
- LogSeverity
- SchedulerType
- SubgraphPreparationType
- TensorType
- WorkerType
The following sections cover the Engine, Model, Tensor, and Buffer classes in detail.
API Functions (Engine)
Engine(Config config)
public Engine(Config config)
Creates an engine with the given configuration.
registerModel
public void registerModel(Model model)
Registers a model to the engine. The model must be registered before it can be used to run inference.
getNumInputTensors
public int getNumInputTensors(Model model)
Returns the number of input tensors of the given model.
getNumOutputTensors
public int getNumOutputTensors(Model model)
Returns the number of output tensors of the given model.
requestSync
public void requestSync(Model model, List<Tensor> inputTensors, List<Tensor> outputTensors)
Runs the given model synchronously with the given input and output tensors. The input and output tensors must be created by the engine.
requestAsync
public Request requestAsync(Model model, List<Tensor> inputTensors)
Runs the given model asynchronously with the given input tensors. The input tensors must be created by the engine. Returns a request object that can be used to wait for the result.
requestAsyncBatch
public List<Request> requestAsyncBatch(List<Model> models, List<List<Tensor>> inputTensorLists)
Runs the given models asynchronously with the given input tensors. The input tensors must be created by the engine. Returns a list of request objects that can be used to wait for the results.
wait
public void wait(Request request, List<Tensor> outputTensors)
Waits for the given request to finish and copies its results into the given output tensors. The output tensors must be created by the engine.
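The submit-then-wait flow of requestAsync and wait resembles the Future pattern in java.util.concurrent. The sketch below illustrates that pattern with the standard library only, as a stand-in: fakeInference, runAll, and the thread pool are hypothetical and do not call into Band.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

final class SubmitThenWaitSketch {
    // Stand-in for an inference job: sums its input. Not a Band call.
    static int fakeInference(int[] input) {
        int sum = 0;
        for (int v : input) {
            sum += v;
        }
        return sum;
    }

    // Submits every job first, then waits for each handle in submission order.
    static List<Integer> runAll(List<int[]> inputs) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            List<Future<Integer>> handles = new ArrayList<>();
            for (int[] input : inputs) {
                handles.add(pool.submit(() -> fakeInference(input)));
            }
            List<Integer> results = new ArrayList<>();
            for (Future<Integer> handle : handles) {
                results.add(handle.get()); // blocks until this job is done
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```

The point of the pattern is that all jobs are in flight before the first blocking call, which is presumably also why requestAsync returns a Request handle instead of blocking.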
createInputTensor
public Tensor createInputTensor(Model model, int index)
Creates an input tensor for the given model and index. The input tensor can be used to run the model.
createOutputTensor
public Tensor createOutputTensor(Model model, int index)
Creates an output tensor for the given model and index. The output tensor can be used to run the model.
API Functions (Model)
Model(BackendType backendType, String filePath)
public Model(BackendType backendType, String filePath)
Creates a model from the given file path. The model must be registered with the engine before it can be used to run inference.
Model(BackendType backendType, ByteBuffer modelBuffer)
public Model(BackendType backendType, ByteBuffer modelBuffer)
Creates a model from the given byte buffer. The model must be registered with the engine before it can be used to run inference.
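One way to obtain a ByteBuffer for this constructor is to memory-map the model file. The following is a minimal sketch using java.nio only; ModelBufferSketch and mapReadOnly are hypothetical names, and whether Band requires a direct or mapped buffer is an assumption we have not verified.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

final class ModelBufferSketch {
    // Maps the whole file read-only. The mapping stays valid after the
    // channel is closed, so try-with-resources is safe here.
    static MappedByteBuffer mapReadOnly(Path file) {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The returned buffer could then be passed as the modelBuffer argument. Note that on Android, java.nio.file requires API level 26 or later.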
getSupportedBackends
public List<BackendType> getSupportedBackends()
Returns a list of supported backends for the model.
API Functions (Tensor)
getType
public DataType getType()
Returns the data type of the tensor.
setType
public void setType(DataType dataType)
Sets the data type of the tensor.
getData
public ByteBuffer getData()
Returns the data buffer of the tensor.
setData
public void setData(ByteBuffer data)
Sets the data buffer of the tensor.
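When preparing data for a float tensor, the byte order of the buffer matters. The sketch below packs a float array into a direct, native-order ByteBuffer, matching the native order used when reading results in the example above. TensorDataSketch and packFloats are hypothetical names; that setData expects a direct buffer is an assumption.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class TensorDataSketch {
    // Packs `values` into a direct ByteBuffer in the platform's native byte order.
    static ByteBuffer packFloats(float[] values) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(values.length * Float.BYTES);
        buffer.order(ByteOrder.nativeOrder());
        // Writing through the float view does not move the ByteBuffer's own
        // position, so the result is returned with position 0.
        buffer.asFloatBuffer().put(values);
        return buffer;
    }
}
```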
getDims
public int[] getDims()
Returns the dimensions of the tensor.
setDims
public void setDims(int[] dims)
Sets the dimensions of the tensor.
getBytes
public int getBytes()
Returns the number of bytes of the tensor.
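For a densely packed tensor, the byte count is normally the product of the dimensions multiplied by the element size. The sketch below shows that arithmetic as a sanity check; TensorSizeSketch and its methods are hypothetical, and that getBytes follows exactly this rule is an assumption.

```java
final class TensorSizeSketch {
    // Number of elements is the product of all dimensions (1 for a scalar).
    static long numElements(int[] dims) {
        long count = 1;
        for (int d : dims) {
            if (d < 0) {
                throw new IllegalArgumentException("negative dimension: " + d);
            }
            count *= d;
        }
        return count;
    }

    // Byte size of a dense tensor with the given element width.
    static long numBytes(int[] dims, int bytesPerElement) {
        return numElements(dims) * bytesPerElement;
    }
}
```

For example, a 1 x 224 x 224 x 3 image tensor with one byte per element occupies 150528 bytes.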
getName
public String getName()
Returns the name of the tensor.
getQuantization
public Quantization getQuantization()
Returns the quantization of the tensor.
setQuantization
public void setQuantization(Quantization quantization)
Sets the quantization of the tensor.
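Quantized TFLite models, such as the one in the example above, store integers that map to real values through an affine transform: real = scale * (quantized - zeroPoint). The sketch below applies that formula to unsigned 8-bit data. The fields exposed by Band's Quantization class are not shown in this document, so scale and zeroPoint are passed as plain parameters here; DequantizeSketch is a hypothetical name.

```java
final class DequantizeSketch {
    // Affine dequantization as defined by the TFLite quantization scheme.
    static float dequantize(int quantized, float scale, int zeroPoint) {
        return scale * (quantized - zeroPoint);
    }

    // Interprets each byte as an unsigned 8-bit value and dequantizes it.
    static float[] dequantizeUint8(byte[] raw, float scale, int zeroPoint) {
        float[] out = new float[raw.length];
        for (int i = 0; i < raw.length; i++) {
            out[i] = dequantize(raw[i] & 0xFF, scale, zeroPoint);
        }
        return out;
    }
}
```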
API Functions (Buffer)
Buffer is a helper class that wraps several kinds of memory buffers (e.g., byte arrays, image planes) behind a single type. It is used to preprocess input images before running inference. It does not deep copy the data, so the underlying data must outlive the buffer and should not be modified while the buffer is in use.
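The consequence of wrapping without a deep copy can be shown with java.nio.ByteBuffer.wrap, which behaves analogously: the wrapper and the array share storage. This is an illustration with the standard library, not Band's Buffer; SharedStorageSketch and its methods are hypothetical.

```java
import java.nio.ByteBuffer;

final class SharedStorageSketch {
    // Wraps the array without copying: later writes to `pixels` are visible
    // through the returned buffer.
    static ByteBuffer wrapWithoutCopy(byte[] pixels) {
        return ByteBuffer.wrap(pixels);
    }

    // Takes an independent copy: later writes to `pixels` are not visible.
    static byte[] snapshot(byte[] pixels) {
        return pixels.clone();
    }
}
```

If the source array is reused for the next camera frame before preprocessing has finished, a non-copying wrapper would observe the new frame, so taking a snapshot first is the safer choice in that situation.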
Buffer(Tensor tensor)
public Buffer(Tensor tensor)
Creates a buffer from the given tensor. The tensor must be created by the engine.
Buffer(byte[] buffer, int width, int height, BufferFormat bufferFormat)
public Buffer(byte[] buffer, int width, int height, BufferFormat bufferFormat)
Creates a buffer from the given byte array, width, height, and format. We currently support RGB, RGBA, and GRAYSCALE formats.
Buffer(byte[][] yuvBytes, int width, int height, int yRowStride, int uvRowStride, int uvPixelStride, BufferFormat bufferFormat)
public Buffer(byte[][] yuvBytes, int width, int height, int yRowStride, int uvRowStride, int uvPixelStride, BufferFormat bufferFormat)
Creates a buffer from the given YUV byte arrays, width, height, row strides, UV pixel stride, and format. We currently support YUV formats (NV21, NV12, YV12, YV21).
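The stride parameters describe how to locate a pixel inside each plane. The sketch below shows the usual index arithmetic for 4:2:0 YUV data, assuming a Y pixel stride of 1 and chroma subsampled by 2 in both directions, as in Android's YUV_420_888 layout. YuvIndexSketch and its methods are hypothetical and are not part of Band.

```java
final class YuvIndexSketch {
    // Byte offset of the luma sample for pixel (x, y) in the Y plane.
    static int yIndex(int x, int y, int yRowStride) {
        return y * yRowStride + x;
    }

    // Byte offset of the chroma sample covering pixel (x, y) in a U or V plane.
    // Each chroma sample covers a 2 x 2 block of pixels.
    static int chromaIndex(int x, int y, int uvRowStride, int uvPixelStride) {
        return (y / 2) * uvRowStride + (x / 2) * uvPixelStride;
    }
}
```

A uvPixelStride of 2 typically indicates interleaved chroma (NV12 or NV21 style), while a stride of 1 indicates separate U and V planes.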
Buffer(Image.Plane[] planes, int width, int height, BufferFormat format)
public Buffer(Image.Plane[] planes, int width, int height, BufferFormat format)
Creates a buffer from the given planes, width, height, and format. We currently support YUV formats (NV21, NV12, YV12, YV21).