Commit b331b9f29062a0e888dafc7a7d19ca8fa53286c1

Authored by Thiago Franco de Moraes
1 parent d5925187
Exists in master

Improvements in brain segmenter gui

invesalius/gui/brain_seg_dialog.py
1 1 #!/usr/bin/env python
2 2 # -*- coding: UTF-8 -*-
3 3  
  4 +import importlib
4 5 import os
5 6 import pathlib
6 7 import sys
... ... @@ -8,13 +9,8 @@ import sys
8 9 import wx
9 10 from wx.lib.pubsub import pub as Publisher
10 11  
11   -HAS_THEANO = True
12   -HAS_PLAIDML = True
13   -
14   -try:
15   - import theano
16   -except ImportError:
17   - HAS_THEANO = False
  12 +HAS_THEANO = bool(importlib.util.find_spec("theano"))
  13 +HAS_PLAIDML = bool(importlib.util.find_spec("plaidml"))
18 14  
19 15 # Linux if installed plaidml with pip3 install --user
20 16 if sys.platform.startswith("linux"):
... ... @@ -29,13 +25,10 @@ elif sys.platform == "darwin":
29 25 os.environ["RUNFILES_DIR"] = str(local_user_plaidml)
30 26 os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute())
31 27  
32   -try:
33   - import plaidml
34   -except ImportError:
35   - HAS_PLAIDML = False
36 28  
37 29 import invesalius.data.slice_ as slc
38 30 from invesalius.segmentation.brain import segment
  31 +from invesalius.segmentation.brain import utils
39 32  
40 33  
41 34  
... ... @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog):
49 42 backends.append("Theano")
50 43 self.segmenter = segment.BrainSegmenter()
51 44 # self.pg_dialog = None
52   -
  45 + self.plaidml_devices = utils.get_plaidml_devices()
53 46 self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY)
54 47 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends)))
55 48 self.cb_backends.SetMinClientSize((w, -1))
56 49 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU"))
  50 + self.lbl_device = wx.StaticText(self, -1, _("Device"))
  51 + self.cb_devices = wx.ComboBox(self, wx.ID_ANY, choices=list(self.plaidml_devices.keys()), value=list(self.plaidml_devices.keys())[0],style=wx.CB_DROPDOWN | wx.CB_READONLY)
57 52 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100)
58 53 w, h = self.CalcSizeFromTextSize("M" * 20)
59 54 self.sld_threshold.SetMinClientSize((w, -1))
... ... @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog):
64 59 self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment"))
65 60 self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop"))
66 61 self.btn_stop.Disable()
  62 + self.btn_close = wx.Button(self, wx.ID_CLOSE)
67 63  
68 64 self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue()))
69 65  
... ... @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog):
76 72 sizer_backends = wx.BoxSizer(wx.HORIZONTAL)
77 73 label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend"))
78 74 sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0)
79   - sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 5)
  75 + sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 0)
80 76 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5)
81 77 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5)
  78 + sizer_devices = wx.BoxSizer(wx.HORIZONTAL)
  79 + sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0)
  80 + sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5)
  81 + main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5)
82 82 label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty"))
83 83 main_sizer.Add(label_5, 0, wx.ALL, 5)
84 84 sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5)
... ... @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog):
86 86 main_sizer.Add(sizer_3, 0, wx.EXPAND, 0)
87 87 main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5)
88 88 sizer_buttons = wx.BoxSizer(wx.HORIZONTAL)
  89 + sizer_buttons.Add(self.btn_close, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5)
89 90 sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5)
90 91 sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5)
91 92 main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0)
92 93 self.SetSizer(main_sizer)
93 94 main_sizer.Fit(self)
94 95 main_sizer.SetSizeHints(self)
  96 +
  97 + self.main_sizer = main_sizer
  98 +
  99 + self.OnSetBackend()
  100 +
95 101 self.Layout()
96 102 self.Centre()
97 103  
98 104 def __set_events(self):
  105 + self.cb_backends.Bind(wx.EVT_COMBOBOX, self.OnSetBackend)
99 106 self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold)
100 107 self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus)
101 108 self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment)
102 109 self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop)
  110 + self.btn_close.Bind(wx.EVT_BUTTON, self.OnBtnClose)
103 111 self.Bind(wx.EVT_CLOSE, self.OnClose)
104 112  
105 113 def CalcSizeFromTextSize(self, text):
... ... @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog):
108 116 width, height = dc.GetTextExtent(text)
109 117 return width, height
110 118  
  119 + def OnSetBackend(self, evt=None):
  120 + if self.cb_backends.GetValue().lower() == "plaidml":
  121 + self.lbl_device.Show()
  122 + self.cb_devices.Show()
  123 + self.chk_use_gpu.Hide()
  124 + else:
  125 + self.lbl_device.Hide()
  126 + self.cb_devices.Hide()
  127 + self.chk_use_gpu.Show()
  128 +
  129 + self.main_sizer.Fit(self)
  130 + self.main_sizer.SetSizeHints(self)
  131 +
111 132 def OnScrollThreshold(self, evt):
112 133 value = self.sld_threshold.GetValue()
113 134 self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue()))
... ... @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog):
136 157 def OnSegment(self, evt):
137 158 image = slc.Slice().matrix
138 159 backend = self.cb_backends.GetValue()
  160 + try:
  161 + device_id = self.plaidml_devices[self.cb_devices.GetValue()]
  162 + except KeyError:
  163 + device_id = "llvm_cpu.0"
139 164 use_gpu = self.chk_use_gpu.GetValue()
140 165 prob_threshold = self.sld_threshold.GetValue() / 100.0
141 166 self.btn_stop.Enable()
142 167 self.btn_segment.Disable()
  168 +
  169 + print(device_id)
143 170 # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME)
144 171 # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop)
145 172 # self.pg_dialog.Show()
146   - self.segmenter.segment(image, prob_threshold, backend, use_gpu, self.SetProgress, self.AfterSegment)
  173 + self.segmenter.segment(image, prob_threshold, backend, device_id, use_gpu, self.SetProgress, self.AfterSegment)
147 174  
148 175 def OnStop(self, evt):
149 176 self.segmenter.stop = True
... ... @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog):
153 180 self.btn_segment.Enable()
154 181 evt.Skip()
155 182  
  183 + def OnBtnClose(self, evt):
  184 + self.Close()
  185 +
156 186 def AfterSegment(self):
  187 + self.btn_stop.Disable()
  188 + self.btn_segment.Disable()
157 189 Publisher.sendMessage('Reload actual slice')
158 190  
159 191 def SetProgress(self, progress):
... ...
invesalius/segmentation/brain/segment.py
... ... @@ -50,14 +50,18 @@ class BrainSegmenter:
50 50 self.stop = False
51 51 self.segmented = False
52 52  
53   - def segment(self, image, prob_threshold, backend, use_gpu, progress_callback=None, after_segment=None):
  53 + def segment(self, image, prob_threshold, backend, device_id, use_gpu, progress_callback=None, after_segment=None):
54 54 print("backend", backend)
55 55 if backend.lower() == 'plaidml':
56 56 os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
57   - device = utils.get_plaidml_devices(use_gpu)
58   - os.environ["PLAIDML_DEVICE_IDS"] = device.id.decode("utf8")
  57 + os.environ["PLAIDML_DEVICE_IDS"] = device_id
59 58 elif backend.lower() == 'theano':
60 59 os.environ["KERAS_BACKEND"] = "theano"
  60 + if use_gpu:
  61 + os.environ["THEANO_FLAGS"] = "device=cuda0"
  62 + print("Use GPU theano", os.environ["THEANO_FLAGS"])
  63 + else:
  64 + os.environ["THEANO_FLAGS"] = "device=cpu"
61 65 else:
62 66 raise TypeError("Wrong backend")
63 67  
... ...
invesalius/segmentation/brain/utils.py
... ... @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False):
5 5 plaidml.settings._setup_for_test(plaidml.settings.user_settings)
6 6 plaidml.settings.experimental = True
7 7 devices, _ = plaidml.devices(ctx, limit=100, return_all=True)
8   - if gpu:
9   - for device in devices:
10   - if b"cuda" in device.description.lower():
11   - return device
12   - for device in devices:
13   - if b"opencl" in device.description.lower():
14   - return device
  8 + out_devices = []
15 9 for device in devices:
16   - if b"llvm" in device.description.lower():
17   - return device
  10 + points = 0
  11 + if b"cuda" in device.description.lower():
  12 + points += 1
  13 + if b"opencl" in device.description.lower():
  14 + points += 1
  15 + if b"nvidia" in device.description.lower():
  16 + points += 1
  17 + if b"amd" in device.description.lower():
  18 + points += 1
  19 + out_devices.append((points, device))
  20 +
  21 + out_devices.sort(reverse=True)
  22 + return {device.description.decode("utf8"): device.id.decode("utf8") for points, device in out_devices }
... ...