diff options
author | Eric Messick <ericm@nanorex.com> | 2008-03-07 22:02:25 +0000 |
---|---|---|
committer | Eric Messick <ericm@nanorex.com> | 2008-03-07 22:02:25 +0000 |
commit | 66cfa1486a8385f8b18be305c215c6f954ab4e0f (patch) | |
tree | 2045c4b96f4bcd7d8594efb5eeb8697189b60dc2 | |
parent | a3b7d560ad0c29e7773cebd84c609502b0820ed6 (diff) | |
download | nanoengineer-66cfa1486a8385f8b18be305c215c6f954ab4e0f.tar.gz nanoengineer-66cfa1486a8385f8b18be305c215c6f954ab4e0f.zip |
samevals.c handles numeric arrays correctly
-rwxr-xr-x | cad/src/samevalshelp.c | 92 | ||||
-rw-r--r-- | cad/src/tests/samevalstests.py | 62 |
2 files changed, 129 insertions, 25 deletions
diff --git a/cad/src/samevalshelp.c b/cad/src/samevalshelp.c index f361001a8..f5357f91b 100755 --- a/cad/src/samevalshelp.c +++ b/cad/src/samevalshelp.c @@ -3,6 +3,7 @@ * Type "python setup2.py build_ext --inplace" to build. */ +#include <alloca.h> #include "Python.h" #include "Numeric/arrayobject.h" @@ -67,24 +68,85 @@ _same_vals_helper(PyObject *v1, PyObject *v2) } else if (arrayType != NULL && typ1 == arrayType) { PyArrayObject *x = (PyArrayObject *) v1; PyArrayObject *y = (PyArrayObject *) v2; - int i, elsize, howmany = 1; - if (x->nd != y->nd) return 1; - for (i = 0; i < x->nd; i++) { + int i; + int elementSize; + int *indices; + int topDimension; + char *xdata; + char *ydata; + int objectCompare = 0; + + // do all quick rejects first (no loops) + if (x->nd != y->nd) { + // number of dimensions doesn't match + return 1; + // note that a (1 x X) array can never equal a single + // dimensional array of length X. + } + if (x->descr->type_num != y->descr->type_num) { + // type of elements doesn't match + return 1; + } + if (x->descr->type_num == PyArray_OBJECT) { + objectCompare = 1; + } + elementSize = x->descr->elsize; + if (elementSize != y->descr->elsize) { + // size of elements doesn't match (shouldn't happen if + // types match!) + return 1; + } + for (i = x->nd - 1; i >= 0; i--) { if (x->dimensions[i] != y->dimensions[i]) + // shapes don't match return 1; - howmany *= x->dimensions[i]; - } - // if one stride is NULL and the other isn't, it's a problem - if ((x->strides == NULL) ^ (y->strides == NULL)) return 1; - if (x->strides != NULL) { - // both non-null, compare them - for (i = 0; i < x->nd; i++) - if (x->strides[i] != y->strides[i]) return 1; } - if (x->descr->type_num != y->descr->type_num) return 1; - elsize = x->descr->elsize; - if (elsize != y->descr->elsize) return 1; - if (memcmp(x->data, y->data, elsize * howmany) != 0) return 1; + // we do a lot of these, so handle them early + if (x->nd == 1 && !objectCompare && x->strides[0]==elementSize && y->strides[0]==elementSize) { + // contiguous one dimensional array of non-objects + return memcmp(x->data, y->data, elementSize * x->dimensions[0]) ? 1 : 0; + } + if (x->nd == 0) { + // scalar, just compare one element + if (objectCompare) { + return _same_vals_helper(*(PyObject **)x->data, *(PyObject **)y->data); + } else { + return memcmp(x->data, y->data, elementSize) ? 1 : 0; + } + } + // If we decide we can't do alloca() for some reason, we can + // either have a fixed maximum dimensionality, or use alloc + // and free. + indices = (int *)alloca(sizeof(int) * x->nd); + for (i = x->nd - 1; i >= 0; i--) { + indices[i] = 0; + } + topDimension = x->dimensions[0]; + while (indices[0] < topDimension) { + xdata = x->data; + ydata = y->data; + for (i = 0; i < x->nd; i++) { + xdata += indices[i] * x->strides[i]; + ydata += indices[i] * y->strides[i]; + } + if (objectCompare) { + if (_same_vals_helper(*(PyObject **)xdata, *(PyObject **)ydata)) { + return 1; + } + } else if (memcmp(xdata, ydata, elementSize) != 0) { + // element mismatch + return 1; + } + // step to next element + for (i = x->nd - 1; i>=0; i--) { + indices[i]++; + if (i == 0 || indices[i] < x->dimensions[i]) { + break; + } + indices[i] = 0; + } + } + // all elements match return 0; } #if 0 diff --git a/cad/src/tests/samevalstests.py b/cad/src/tests/samevalstests.py index 155f0afda..53240f49f 100644 --- a/cad/src/tests/samevalstests.py +++ b/cad/src/tests/samevalstests.py @@ -51,19 +51,61 @@ class SameValsTests(unittest.TestCase): assert same_vals((1, 2), (1, 2)) assert not same_vals((1, 2), (2, 1)) - def test_numeric_equals(self): + def test_numericArray1(self): a = Numeric.array((1, 2, 3)) b = Numeric.array((1, 2, 3)) - print a == b - print a != b - assert a == b - assert not a != b - b = Numeric.array((1, 4, 5)) - print a == b - print a != b - assert a != b - assert not a == b + assert same_vals(a, b) + b = Numeric.array((1, 2, 4)) + assert not same_vals(a, b) + b = Numeric.array((1, 2)) + assert not same_vals(a, b) + a = Numeric.array([[1, 2], [3, 4]]) + b = Numeric.array([[1, 2], [3, 4]]) + assert same_vals(a, b) + + b = Numeric.array([4, 3]) + c = a[1, 1::-1] + assert same_vals(b, c) + + a = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + [[[28, 29, 30], [31, 32, 33], [34, 35, 36]], + [[37, 38, 39], [40, 41, 42], [43, 44, 45]], + [[46, 47, 48], [49, 50, 51], [52, 53, 54]]]]) + b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + [[[28, 29, 30], [31, 32, 33], [34, 35, 36]], + [[37, 38, 39], [40, 41, 42], [43, 44, 45]], + [[46, 47, 48], [49, 50, 51], [52, 53, 54]]]]) + assert same_vals(a, b) + b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + [[[28, 29, 30], [31, 32, 33], [34, 35, 36]], + [[37, 38, 39], [40, 41, 42], [43, 44, 45]], + [[46, 47, 48], [49, 50, 51], [52, 53, 55]]]]) + assert not same_vals(a, b) + b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], + [[[28, 29, 30], [31, 30, 33], [34, 35, 36]], + [[37, 38, 39], [40, 41, 42], [43, 44, 45]], + [[46, 47, 48], [49, 50, 51], [52, 53, 54]]]]) + assert not same_vals(a, b) + b = Numeric.array([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]]) + assert not same_vals(a, b) + + a = Numeric.array(["abc", "def"], Numeric.PyObject) + b = Numeric.array(["abc", "def"], Numeric.PyObject) + assert same_vals(a, b) + b = Numeric.array(["abc", "defg"], Numeric.PyObject) + assert not same_vals(a, b) + def test(): suite = unittest.makeSuite(SameValsTests, 'test') runner = unittest.TextTestRunner() |