From db65e15345734aa81f83de7acffcfae2ae57029f Mon Sep 17 00:00:00 2001 From: Thiago Franco de Moraes Date: Tue, 24 Feb 2015 13:49:16 -0300 Subject: [PATCH] Added a cython version to threshold --- invesalius/data/slice_.py | 16 ++++++++++------ invesalius/data/threshold.pyx | 32 ++++++++++++++++++++++++++++++++ setup.py | 7 ++++++- 3 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 invesalius/data/threshold.pyx diff --git a/invesalius/data/slice_.py b/invesalius/data/slice_.py index 7fb0a24..49ee223 100644 --- a/invesalius/data/slice_.py +++ b/invesalius/data/slice_.py @@ -35,6 +35,8 @@ from mask import Mask from project import Project from data import mips +from data import threshold + OTHER=0 PLIST=1 WIDGET=2 @@ -1206,12 +1208,14 @@ class Slice(object): given slice_matrix. """ thresh_min, thresh_max = self.current_mask.threshold_range - m = (((slice_matrix >= thresh_min) & (slice_matrix <= thresh_max)) * 255) - m[mask == 1] = 1 - m[mask == 2] = 2 - m[mask == 253] = 253 - m[mask == 254] = 254 - return m.astype('uint8') + #m = (((slice_matrix >= thresh_min) & (slice_matrix <= thresh_max)) * 255) + #m[mask == 1] = 1 + #m[mask == 2] = 2 + #m[mask == 253] = 253 + #m[mask == 254] = 254 + m = numpy.zeros_like(mask, dtype='uint8') + threshold.threshold(slice_matrix, m, thresh_min, thresh_max) + return m def do_threshold_to_all_slices(self): mask = self.current_mask diff --git a/invesalius/data/threshold.pyx b/invesalius/data/threshold.pyx new file mode 100644 index 0000000..e375f7b --- /dev/null +++ b/invesalius/data/threshold.pyx @@ -0,0 +1,32 @@ +import numpy as np +cimport numpy as np +cimport cython + +from libc.math cimport floor, ceil, sqrt, fabs +from cython.parallel import prange + +DTYPE8 = np.uint8 +ctypedef np.uint8_t DTYPE8_t + +DTYPE16 = np.int16 +ctypedef np.int16_t DTYPE16_t + +DTYPEF32 = np.float32 +ctypedef np.float32_t DTYPEF32_t + +@cython.boundscheck(False) # turn of bounds-checking for entire function +@cython.cdivision(True) +@cython.wraparound(False) +def threshold(DTYPE16_t[:, :] image, DTYPE8_t[:, :] mask, DTYPE16_t low, DTYPE16_t high): + cdef int sy = image.shape[0] + cdef int sx = image.shape[1] + cdef int x, y + cdef DTYPE16_t v + for y in prange(sy, nogil=True): + for x in xrange(sx): + v = mask[y, x] + if not v: + if v >= low and v <= high: + mask[y, x] = 255 + else: + mask[y, x] = 0 diff --git a/setup.py b/setup.py index 794452e..f601e2d 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,12 @@ if sys.platform == 'linux2': ext_modules = [ Extension("invesalius.data.mips", ["invesalius/data/mips.pyx"], include_dirs = [numpy.get_include()], extra_compile_args=['-fopenmp'], - extra_link_args=['-fopenmp'],)] + extra_link_args=['-fopenmp'],), + Extension("invesalius.data.threshold", ["invesalius/data/threshold.pyx"], + include_dirs = [numpy.get_include()], + extra_compile_args=['-fopenmp'], + extra_link_args=['-fopenmp'],), + ] ) elif sys.platform == 'win32': -- libgit2 0.21.2