diff --git a/gojni.h b/gojni.h index 2d0cb86..008dfd6 100644 --- a/gojni.h +++ b/gojni.h @@ -37,6 +37,9 @@ __attribute__ ((visibility ("hidden"))) jlong _jni_CallLongMethodA(JNIEnv *env, __attribute__ ((visibility ("hidden"))) jfloat _jni_CallFloatMethodA(JNIEnv *env, jobject obj, jmethodID method, jvalue *args); __attribute__ ((visibility ("hidden"))) jdouble _jni_CallDoubleMethodA(JNIEnv *env, jobject obj, jmethodID method, jvalue *args); __attribute__ ((visibility ("hidden"))) jbyteArray _jni_NewByteArray(JNIEnv *env, jsize length); +__attribute__ ((visibility ("hidden"))) jobjectArray _jni_NewObjectArray(JNIEnv *env, jsize length, jclass elementClass, jobject initialElement); +__attribute__ ((visibility ("hidden"))) jobject _jni_GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index); +__attribute__ ((visibility ("hidden"))) void _jni_SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject value); __attribute__ ((visibility ("hidden"))) jbyte *_jni_GetByteArrayElements(JNIEnv *env, jbyteArray arr); __attribute__ ((visibility ("hidden"))) void _jni_ReleaseByteArrayElements(JNIEnv *env, jbyteArray arr, jbyte *elems, jint mode); __attribute__ ((visibility ("hidden"))) jsize _jni_GetArrayLength(JNIEnv *env, jarray arr); diff --git a/jni.c b/jni.c index ef7ee95..2aae752 100644 --- a/jni.c +++ b/jni.c @@ -178,6 +178,18 @@ jsize _jni_GetArrayLength(JNIEnv *env, jarray arr) { return (*env)->GetArrayLength(env, arr); } +jobjectArray _jni_NewObjectArray(JNIEnv *env, jsize length, jclass elementClass, jobject initialElement) { + return (*env)->NewObjectArray(env, length, elementClass, initialElement); +} + +jobject _jni_GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) { + return (*env)->GetObjectArrayElement(env, array, index); +} + +void _jni_SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject value) { + return (*env)->SetObjectArrayElement(env, array, index, value); +} + jobject _jni_GetStaticObjectField(JNIEnv *env, jclass clazz, jfieldID fieldID) { return (*env)->GetStaticObjectField(env, clazz, fieldID); } diff --git a/jni.go b/jni.go index 325ae1c..22ec8f4 100644 --- a/jni.go +++ b/jni.go @@ -33,13 +33,15 @@ type Env struct { } type ( - Class C.jclass - Object C.jobject - MethodID C.jmethodID - FieldID C.jfieldID - String C.jstring - ByteArray C.jbyteArray - Value uint64 // All JNI types fit into 64-bits. + Class C.jclass + Object C.jobject + MethodID C.jmethodID + FieldID C.jfieldID + String C.jstring + ByteArray C.jbyteArray + ObjectArray C.jobjectArray + Size C.jsize + Value uint64 // All JNI types fit into 64-bits. ) const ( @@ -221,6 +223,28 @@ func NewByteArray(e Env, content []byte) ByteArray { return ByteArray(jarr) } +func NewObjectArray(e Env, len Size, class Class, elem Object) ObjectArray { + jarr := C._jni_NewObjectArray(e.env, C.jsize(len), C.jclass(class), C.jobject(elem)) + if jarr == 0 { + panic(fmt.Errorf("jni: NewObjectArray failed")) + } + return ObjectArray(jarr) +} + +func GetObjectArrayElement(e Env, jarr ObjectArray, index Size) (Object, error) { + jobj := C._jni_GetObjectArrayElement(e.env, C.jobjectArray(jarr), C.jsize(index)) + return Object(jobj), exception(e) +} + +func SetObjectArrayElement(e Env, jarr ObjectArray, index Size, value Object) error { + C._jni_SetObjectArrayElement( + e.env, + C.jobjectArray(jarr), + C.jsize(index), + C.jobject(value)) + return exception(e) +} + // ClassLoader returns a reference to the Java ClassLoader associated // with obj. func ClassLoaderFor(e Env, obj Object) Object { diff --git a/jni_test.go b/jni_test.go index d9d2eb6..f5c6c00 100644 --- a/jni_test.go +++ b/jni_test.go @@ -152,7 +152,7 @@ func TestStaticMethod(t *testing.T) { }() Do(vm, func(env Env) error { cls := FindClass(env, "test/AClass") - GetStaticMethodID(env, cls, "noSuchMethod","()V") + GetStaticMethodID(env, cls, "noSuchMethod", "()V") return nil }) } @@ -208,7 +208,7 @@ func TestMethod(t *testing.T) { }() Do(vm, func(env Env) error { cls := FindClass(env, "test/AClass") - GetMethodID(env, cls, "noSuchMethod","()V") + GetMethodID(env, cls, "noSuchMethod", "()V") return nil }) } @@ -791,11 +791,11 @@ func TestStaticField2(t *testing.T) { } func TestByteArray(t *testing.T) { - arr := []byte{'a','b','c'} + arr := []byte{'a', 'b', 'c'} Do(vm, func(env Env) error { jarr := NewByteArray(env, arr) arr2 := GetByteArrayElements(env, jarr) - if arr2[0] != 'a' || + if arr2[0] != 'a' || arr2[1] != 'b' || arr2[2] != 'c' || len(arr2) != 3 { @@ -805,6 +805,42 @@ func TestByteArray(t *testing.T) { }) } +func TestObjectArray(t *testing.T) { + Do(vm, func(env Env) error { + cls := FindClass(env, "java/lang/Object") + if cls == 0 { + t.Errorf("Class is nil") + } + + strings := []string{ + "item1", + "item2", + } + + arr := NewObjectArray(env, Size(len(strings)), cls, 0) + for i, s := range strings { + jstring := JavaString(env, s) + err := SetObjectArrayElement(env, arr, Size(i), Object(jstring)) + if err != nil { + t.Errorf("SetObjectArrayElement at index %d raised exception", i) + } + } + + for i, s := range strings { + r, err := GetObjectArrayElement(env, arr, Size(i)) + if err != nil { + t.Errorf("GetObjectArrayElement at index %d raised exception", i) + } + + retrieved := GoString(env, String(r)) + if s != retrieved { + t.Errorf("Expected item %d to be '%s' but instead was '%s'", i, s, retrieved) + } + } + return nil + }) +} + func TestClassLoaderFor(t *testing.T) { defer func() { if r := recover(); r == nil {