Commit db65e15345734aa81f83de7acffcfae2ae57029f

Authored by Thiago Franco de Moraes
1 parent 3867177c
Exists in cython_threshold

Added a cython version to threshold

invesalius/data/slice_.py
... ... @@ -35,6 +35,8 @@ from mask import Mask
35 35 from project import Project
36 36 from data import mips
37 37  
  38 +from data import threshold
  39 +
38 40 OTHER=0
39 41 PLIST=1
40 42 WIDGET=2
... ... @@ -1206,12 +1208,14 @@ class Slice(object):
1206 1208 given slice_matrix.
1207 1209 """
1208 1210 thresh_min, thresh_max = self.current_mask.threshold_range
1209   - m = (((slice_matrix >= thresh_min) & (slice_matrix <= thresh_max)) * 255)
1210   - m[mask == 1] = 1
1211   - m[mask == 2] = 2
1212   - m[mask == 253] = 253
1213   - m[mask == 254] = 254
1214   - return m.astype('uint8')
  1211 + #m = (((slice_matrix >= thresh_min) & (slice_matrix <= thresh_max)) * 255)
  1212 + #m[mask == 1] = 1
  1213 + #m[mask == 2] = 2
  1214 + #m[mask == 253] = 253
  1215 + #m[mask == 254] = 254
  1216 + m = numpy.zeros_like(mask, dtype='uint8')
  1217 + threshold.threshold(slice_matrix, m, thresh_min, thresh_max)
  1218 + return m
1215 1219  
1216 1220 def do_threshold_to_all_slices(self):
1217 1221 mask = self.current_mask
... ...
invesalius/data/threshold.pyx 0 → 100644
... ... @@ -0,0 +1,32 @@
  1 +import numpy as np
  2 +cimport numpy as np
  3 +cimport cython
  4 +
  5 +from libc.math cimport floor, ceil, sqrt, fabs
  6 +from cython.parallel import prange
  7 +
  8 +DTYPE8 = np.uint8
  9 +ctypedef np.uint8_t DTYPE8_t
  10 +
  11 +DTYPE16 = np.int16
  12 +ctypedef np.int16_t DTYPE16_t
  13 +
  14 +DTYPEF32 = np.float32
  15 +ctypedef np.float32_t DTYPEF32_t
  16 +
  17 +@cython.boundscheck(False) # turn of bounds-checking for entire function
  18 +@cython.cdivision(True)
  19 +@cython.wraparound(False)
  20 +def threshold(DTYPE16_t[:, :] image, DTYPE8_t[:, :] mask, DTYPE16_t low, DTYPE16_t high):
  21 + cdef int sy = image.shape[0]
  22 + cdef int sx = image.shape[1]
  23 + cdef int x, y
  24 + cdef DTYPE16_t v
  25 + for y in prange(sy, nogil=True):
  26 + for x in xrange(sx):
  27 + v = mask[y, x]
  28 + if not v:
  29 + if v >= low and v <= high:
  30 + mask[y, x] = 255
  31 + else:
  32 + mask[y, x] = 0
... ...
setup.py
... ... @@ -12,7 +12,12 @@ if sys.platform == &#39;linux2&#39;:
12 12 ext_modules = [ Extension("invesalius.data.mips", ["invesalius/data/mips.pyx"],
13 13 include_dirs = [numpy.get_include()],
14 14 extra_compile_args=['-fopenmp'],
15   - extra_link_args=['-fopenmp'],)]
  15 + extra_link_args=['-fopenmp'],),
  16 + Extension("invesalius.data.threshold", ["invesalius/data/threshold.pyx"],
  17 + include_dirs = [numpy.get_include()],
  18 + extra_compile_args=['-fopenmp'],
  19 + extra_link_args=['-fopenmp'],),
  20 + ]
16 21 )
17 22  
18 23 elif sys.platform == 'win32':
... ...