summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Messick <ericm@nanorex.com>2008-03-07 22:02:25 +0000
committerEric Messick <ericm@nanorex.com>2008-03-07 22:02:25 +0000
commit66cfa1486a8385f8b18be305c215c6f954ab4e0f (patch)
tree2045c4b96f4bcd7d8594efb5eeb8697189b60dc2
parenta3b7d560ad0c29e7773cebd84c609502b0820ed6 (diff)
downloadnanoengineer-66cfa1486a8385f8b18be305c215c6f954ab4e0f.tar.gz
nanoengineer-66cfa1486a8385f8b18be305c215c6f954ab4e0f.zip
samevals.c handles numeric arrays correctly
-rwxr-xr-xcad/src/samevalshelp.c92
-rw-r--r--cad/src/tests/samevalstests.py62
2 files changed, 129 insertions, 25 deletions
diff --git a/cad/src/samevalshelp.c b/cad/src/samevalshelp.c
index f361001a8..f5357f91b 100755
--- a/cad/src/samevalshelp.c
+++ b/cad/src/samevalshelp.c
@@ -3,6 +3,7 @@
* Type "python setup2.py build_ext --inplace" to build.
*/
+#include <alloca.h>
#include "Python.h"
#include "Numeric/arrayobject.h"
@@ -67,24 +68,85 @@ _same_vals_helper(PyObject *v1, PyObject *v2)
} else if (arrayType != NULL && typ1 == arrayType) {
PyArrayObject *x = (PyArrayObject *) v1;
PyArrayObject *y = (PyArrayObject *) v2;
- int i, elsize, howmany = 1;
- if (x->nd != y->nd) return 1;
- for (i = 0; i < x->nd; i++) {
+ int i;
+ int elementSize;
+ int *indices;
+ int topDimension;
+ char *xdata;
+ char *ydata;
+ int objectCompare = 0;
+
+ // do all quick rejects first (no loops)
+ if (x->nd != y->nd) {
+ // number of dimensions doesn't match
+ return 1;
+ // note that a (1 x X) array can never equal a single
+ // dimensional array of length X.
+ }
+ if (x->descr->type_num != y->descr->type_num) {
+ // type of elements doesn't match
+ return 1;
+ }
+ if (x->descr->type_num == PyArray_OBJECT) {
+ objectCompare = 1;
+ }
+ elementSize = x->descr->elsize;
+ if (elementSize != y->descr->elsize) {
+ // size of elements doesn't match (shouldn't happen if
+ // types match!)
+ return 1;
+ }
+ for (i = x->nd - 1; i >= 0; i--) {
if (x->dimensions[i] != y->dimensions[i])
+ // shapes don't match
return 1;
- howmany *= x->dimensions[i];
- }
- // if one stride is NULL and the other isn't, it's a problem
- if ((x->strides == NULL) ^ (y->strides == NULL)) return 1;
- if (x->strides != NULL) {
- // both non-null, compare them
- for (i = 0; i < x->nd; i++)
- if (x->strides[i] != y->strides[i]) return 1;
}
- if (x->descr->type_num != y->descr->type_num) return 1;
- elsize = x->descr->elsize;
- if (elsize != y->descr->elsize) return 1;
- if (memcmp(x->data, y->data, elsize * howmany) != 0) return 1;
+ // we do a lot of these, so handle them early
+ if (x->nd == 1 && !objectCompare && x->strides[0]==elementSize && y->strides[0]==elementSize) {
+ // contiguous one dimensional array of non-objects
+ return memcmp(x->data, y->data, elementSize * x->dimensions[0]) ? 1 : 0;
+ }
+ if (x->nd == 0) {
+ // scalar, just compare one element
+ if (objectCompare) {
+ return _same_vals_helper(*(PyObject **)x->data, *(PyObject **)y->data);
+ } else {
+ return memcmp(x->data, y->data, elementSize) ? 1 : 0;
+ }
+ }
+ // If we decide we can't do alloca() for some reason, we can
+ // either have a fixed maximum dimensionality, or use alloc
+ // and free.
+ indices = (int *)alloca(sizeof(int) * x->nd);
+ for (i = x->nd - 1; i >= 0; i--) {
+ indices[i] = 0;
+ }
+ topDimension = x->dimensions[0];
+ while (indices[0] < topDimension) {
+ xdata = x->data;
+ ydata = y->data;
+ for (i = 0; i < x->nd; i++) {
+ xdata += indices[i] * x->strides[i];
+ ydata += indices[i] * y->strides[i];
+ }
+ if (objectCompare) {
+ if (_same_vals_helper(*(PyObject **)xdata, *(PyObject **)ydata)) {
+ return 1;
+ }
+ } else if (memcmp(xdata, ydata, elementSize) != 0) {
+ // element mismatch
+ return 1;
+ }
+ // step to next element
+ for (i = x->nd - 1; i>=0; i--) {
+ indices[i]++;
+ if (i == 0 || indices[i] < x->dimensions[i]) {
+ break;
+ }
+ indices[i] = 0;
+ }
+ }
+ // all elements match
return 0;
}
#if 0
diff --git a/cad/src/tests/samevalstests.py b/cad/src/tests/samevalstests.py
index 155f0afda..53240f49f 100644
--- a/cad/src/tests/samevalstests.py
+++ b/cad/src/tests/samevalstests.py
@@ -51,19 +51,61 @@ class SameValsTests(unittest.TestCase):
assert same_vals((1, 2), (1, 2))
assert not same_vals((1, 2), (2, 1))
- def test_numeric_equals(self):
+ def test_numericArray1(self):
a = Numeric.array((1, 2, 3))
b = Numeric.array((1, 2, 3))
- print a == b
- print a != b
- assert a == b
- assert not a != b
- b = Numeric.array((1, 4, 5))
- print a == b
- print a != b
- assert a != b
- assert not a == b
+ assert same_vals(a, b)
+ b = Numeric.array((1, 2, 4))
+ assert not same_vals(a, b)
+ b = Numeric.array((1, 2))
+ assert not same_vals(a, b)
+ a = Numeric.array([[1, 2], [3, 4]])
+ b = Numeric.array([[1, 2], [3, 4]])
+ assert same_vals(a, b)
+
+ b = Numeric.array([4, 3])
+ c = a[1, 1::-1]
+ assert same_vals(b, c)
+
+ a = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]],
+ [[[28, 29, 30], [31, 32, 33], [34, 35, 36]],
+ [[37, 38, 39], [40, 41, 42], [43, 44, 45]],
+ [[46, 47, 48], [49, 50, 51], [52, 53, 54]]]])
+ b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]],
+ [[[28, 29, 30], [31, 32, 33], [34, 35, 36]],
+ [[37, 38, 39], [40, 41, 42], [43, 44, 45]],
+ [[46, 47, 48], [49, 50, 51], [52, 53, 54]]]])
+ assert same_vals(a, b)
+ b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]],
+ [[[28, 29, 30], [31, 32, 33], [34, 35, 36]],
+ [[37, 38, 39], [40, 41, 42], [43, 44, 45]],
+ [[46, 47, 48], [49, 50, 51], [52, 53, 55]]]])
+ assert not same_vals(a, b)
+ b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]],
+ [[[28, 29, 30], [31, 30, 33], [34, 35, 36]],
+ [[37, 38, 39], [40, 41, 42], [43, 44, 45]],
+ [[46, 47, 48], [49, 50, 51], [52, 53, 54]]]])
+ assert not same_vals(a, b)
+ b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]],
+ [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+ [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]])
+ assert not same_vals(a, b)
+
+ a = Numeric.array(["abc", "def"], Numeric.PyObject)
+ b = Numeric.array(["abc", "def"], Numeric.PyObject)
+ assert same_vals(a, b)
+ b = Numeric.array(["abc", "defg"], Numeric.PyObject)
+ assert not same_vals(a, b)
+
def test():
suite = unittest.makeSuite(SameValsTests, 'test')
runner = unittest.TextTestRunner()