Commit b331b9f29062a0e888dafc7a7d19ca8fa53286c1
1 parent
d5925187
Exists in
master
Improvements in brain segmenter gui
Showing
3 changed files
with
67 additions
and
26 deletions
Show diff stats
invesalius/gui/brain_seg_dialog.py
| 1 | #!/usr/bin/env python | 1 | #!/usr/bin/env python |
| 2 | # -*- coding: UTF-8 -*- | 2 | # -*- coding: UTF-8 -*- |
| 3 | 3 | ||
| 4 | +import importlib | ||
| 4 | import os | 5 | import os |
| 5 | import pathlib | 6 | import pathlib |
| 6 | import sys | 7 | import sys |
| @@ -8,13 +9,8 @@ import sys | @@ -8,13 +9,8 @@ import sys | ||
| 8 | import wx | 9 | import wx |
| 9 | from wx.lib.pubsub import pub as Publisher | 10 | from wx.lib.pubsub import pub as Publisher |
| 10 | 11 | ||
| 11 | -HAS_THEANO = True | ||
| 12 | -HAS_PLAIDML = True | ||
| 13 | - | ||
| 14 | -try: | ||
| 15 | - import theano | ||
| 16 | -except ImportError: | ||
| 17 | - HAS_THEANO = False | 12 | +HAS_THEANO = bool(importlib.util.find_spec("theano")) |
| 13 | +HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) | ||
| 18 | 14 | ||
| 19 | # Linux if installed plaidml with pip3 install --user | 15 | # Linux if installed plaidml with pip3 install --user |
| 20 | if sys.platform.startswith("linux"): | 16 | if sys.platform.startswith("linux"): |
| @@ -29,13 +25,10 @@ elif sys.platform == "darwin": | @@ -29,13 +25,10 @@ elif sys.platform == "darwin": | ||
| 29 | os.environ["RUNFILES_DIR"] = str(local_user_plaidml) | 25 | os.environ["RUNFILES_DIR"] = str(local_user_plaidml) |
| 30 | os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) | 26 | os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) |
| 31 | 27 | ||
| 32 | -try: | ||
| 33 | - import plaidml | ||
| 34 | -except ImportError: | ||
| 35 | - HAS_PLAIDML = False | ||
| 36 | 28 | ||
| 37 | import invesalius.data.slice_ as slc | 29 | import invesalius.data.slice_ as slc |
| 38 | from invesalius.segmentation.brain import segment | 30 | from invesalius.segmentation.brain import segment |
| 31 | +from invesalius.segmentation.brain import utils | ||
| 39 | 32 | ||
| 40 | 33 | ||
| 41 | 34 | ||
| @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 49 | backends.append("Theano") | 42 | backends.append("Theano") |
| 50 | self.segmenter = segment.BrainSegmenter() | 43 | self.segmenter = segment.BrainSegmenter() |
| 51 | # self.pg_dialog = None | 44 | # self.pg_dialog = None |
| 52 | - | 45 | + self.plaidml_devices = utils.get_plaidml_devices() |
| 53 | self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) | 46 | self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) |
| 54 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) | 47 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
| 55 | self.cb_backends.SetMinClientSize((w, -1)) | 48 | self.cb_backends.SetMinClientSize((w, -1)) |
| 56 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) | 49 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
| 50 | + self.lbl_device = wx.StaticText(self, -1, _("Device")) | ||
| 51 | + self.cb_devices = wx.ComboBox(self, wx.ID_ANY, choices=list(self.plaidml_devices.keys()), value=list(self.plaidml_devices.keys())[0],style=wx.CB_DROPDOWN | wx.CB_READONLY) | ||
| 57 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) | 52 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
| 58 | w, h = self.CalcSizeFromTextSize("M" * 20) | 53 | w, h = self.CalcSizeFromTextSize("M" * 20) |
| 59 | self.sld_threshold.SetMinClientSize((w, -1)) | 54 | self.sld_threshold.SetMinClientSize((w, -1)) |
| @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 64 | self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) | 59 | self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) |
| 65 | self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) | 60 | self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) |
| 66 | self.btn_stop.Disable() | 61 | self.btn_stop.Disable() |
| 62 | + self.btn_close = wx.Button(self, wx.ID_CLOSE) | ||
| 67 | 63 | ||
| 68 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) | 64 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) |
| 69 | 65 | ||
| @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 76 | sizer_backends = wx.BoxSizer(wx.HORIZONTAL) | 72 | sizer_backends = wx.BoxSizer(wx.HORIZONTAL) |
| 77 | label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) | 73 | label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) |
| 78 | sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) | 74 | sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) |
| 79 | - sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 5) | 75 | + sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 0) |
| 80 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) | 76 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
| 81 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) | 77 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
| 78 | + sizer_devices = wx.BoxSizer(wx.HORIZONTAL) | ||
| 79 | + sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) | ||
| 80 | + sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) | ||
| 81 | + main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) | ||
| 82 | label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) | 82 | label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) |
| 83 | main_sizer.Add(label_5, 0, wx.ALL, 5) | 83 | main_sizer.Add(label_5, 0, wx.ALL, 5) |
| 84 | sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) | 84 | sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) |
| @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 86 | main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) | 86 | main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) |
| 87 | main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) | 87 | main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) |
| 88 | sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) | 88 | sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) |
| 89 | + sizer_buttons.Add(self.btn_close, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | ||
| 89 | sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | 90 | sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) |
| 90 | sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | 91 | sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) |
| 91 | main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) | 92 | main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) |
| 92 | self.SetSizer(main_sizer) | 93 | self.SetSizer(main_sizer) |
| 93 | main_sizer.Fit(self) | 94 | main_sizer.Fit(self) |
| 94 | main_sizer.SetSizeHints(self) | 95 | main_sizer.SetSizeHints(self) |
| 96 | + | ||
| 97 | + self.main_sizer = main_sizer | ||
| 98 | + | ||
| 99 | + self.OnSetBackend() | ||
| 100 | + | ||
| 95 | self.Layout() | 101 | self.Layout() |
| 96 | self.Centre() | 102 | self.Centre() |
| 97 | 103 | ||
| 98 | def __set_events(self): | 104 | def __set_events(self): |
| 105 | + self.cb_backends.Bind(wx.EVT_COMBOBOX, self.OnSetBackend) | ||
| 99 | self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) | 106 | self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) |
| 100 | self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) | 107 | self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) |
| 101 | self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) | 108 | self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) |
| 102 | self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) | 109 | self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) |
| 110 | + self.btn_close.Bind(wx.EVT_BUTTON, self.OnBtnClose) | ||
| 103 | self.Bind(wx.EVT_CLOSE, self.OnClose) | 111 | self.Bind(wx.EVT_CLOSE, self.OnClose) |
| 104 | 112 | ||
| 105 | def CalcSizeFromTextSize(self, text): | 113 | def CalcSizeFromTextSize(self, text): |
| @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 108 | width, height = dc.GetTextExtent(text) | 116 | width, height = dc.GetTextExtent(text) |
| 109 | return width, height | 117 | return width, height |
| 110 | 118 | ||
| 119 | + def OnSetBackend(self, evt=None): | ||
| 120 | + if self.cb_backends.GetValue().lower() == "plaidml": | ||
| 121 | + self.lbl_device.Show() | ||
| 122 | + self.cb_devices.Show() | ||
| 123 | + self.chk_use_gpu.Hide() | ||
| 124 | + else: | ||
| 125 | + self.lbl_device.Hide() | ||
| 126 | + self.cb_devices.Hide() | ||
| 127 | + self.chk_use_gpu.Show() | ||
| 128 | + | ||
| 129 | + self.main_sizer.Fit(self) | ||
| 130 | + self.main_sizer.SetSizeHints(self) | ||
| 131 | + | ||
| 111 | def OnScrollThreshold(self, evt): | 132 | def OnScrollThreshold(self, evt): |
| 112 | value = self.sld_threshold.GetValue() | 133 | value = self.sld_threshold.GetValue() |
| 113 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) | 134 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) |
| @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 136 | def OnSegment(self, evt): | 157 | def OnSegment(self, evt): |
| 137 | image = slc.Slice().matrix | 158 | image = slc.Slice().matrix |
| 138 | backend = self.cb_backends.GetValue() | 159 | backend = self.cb_backends.GetValue() |
| 160 | + try: | ||
| 161 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | ||
| 162 | + except KeyError: | ||
| 163 | + device_id = "llvm_cpu.0" | ||
| 139 | use_gpu = self.chk_use_gpu.GetValue() | 164 | use_gpu = self.chk_use_gpu.GetValue() |
| 140 | prob_threshold = self.sld_threshold.GetValue() / 100.0 | 165 | prob_threshold = self.sld_threshold.GetValue() / 100.0 |
| 141 | self.btn_stop.Enable() | 166 | self.btn_stop.Enable() |
| 142 | self.btn_segment.Disable() | 167 | self.btn_segment.Disable() |
| 168 | + | ||
| 169 | + print(device_id) | ||
| 143 | # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) | 170 | # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) |
| 144 | # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) | 171 | # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) |
| 145 | # self.pg_dialog.Show() | 172 | # self.pg_dialog.Show() |
| 146 | - self.segmenter.segment(image, prob_threshold, backend, use_gpu, self.SetProgress, self.AfterSegment) | 173 | + self.segmenter.segment(image, prob_threshold, backend, device_id, use_gpu, self.SetProgress, self.AfterSegment) |
| 147 | 174 | ||
| 148 | def OnStop(self, evt): | 175 | def OnStop(self, evt): |
| 149 | self.segmenter.stop = True | 176 | self.segmenter.stop = True |
| @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): | ||
| 153 | self.btn_segment.Enable() | 180 | self.btn_segment.Enable() |
| 154 | evt.Skip() | 181 | evt.Skip() |
| 155 | 182 | ||
| 183 | + def OnBtnClose(self, evt): | ||
| 184 | + self.Close() | ||
| 185 | + | ||
| 156 | def AfterSegment(self): | 186 | def AfterSegment(self): |
| 187 | + self.btn_stop.Disable() | ||
| 188 | + self.btn_segment.Disable() | ||
| 157 | Publisher.sendMessage('Reload actual slice') | 189 | Publisher.sendMessage('Reload actual slice') |
| 158 | 190 | ||
| 159 | def SetProgress(self, progress): | 191 | def SetProgress(self, progress): |
invesalius/segmentation/brain/segment.py
| @@ -50,14 +50,18 @@ class BrainSegmenter: | @@ -50,14 +50,18 @@ class BrainSegmenter: | ||
| 50 | self.stop = False | 50 | self.stop = False |
| 51 | self.segmented = False | 51 | self.segmented = False |
| 52 | 52 | ||
| 53 | - def segment(self, image, prob_threshold, backend, use_gpu, progress_callback=None, after_segment=None): | 53 | + def segment(self, image, prob_threshold, backend, device_id, use_gpu, progress_callback=None, after_segment=None): |
| 54 | print("backend", backend) | 54 | print("backend", backend) |
| 55 | if backend.lower() == 'plaidml': | 55 | if backend.lower() == 'plaidml': |
| 56 | os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" | 56 | os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" |
| 57 | - device = utils.get_plaidml_devices(use_gpu) | ||
| 58 | - os.environ["PLAIDML_DEVICE_IDS"] = device.id.decode("utf8") | 57 | + os.environ["PLAIDML_DEVICE_IDS"] = device_id |
| 59 | elif backend.lower() == 'theano': | 58 | elif backend.lower() == 'theano': |
| 60 | os.environ["KERAS_BACKEND"] = "theano" | 59 | os.environ["KERAS_BACKEND"] = "theano" |
| 60 | + if use_gpu: | ||
| 61 | + os.environ["THEANO_FLAGS"] = "device=cuda0" | ||
| 62 | + print("Use GPU theano", os.environ["THEANO_FLAGS"]) | ||
| 63 | + else: | ||
| 64 | + os.environ["THEANO_FLAGS"] = "device=cpu" | ||
| 61 | else: | 65 | else: |
| 62 | raise TypeError("Wrong backend") | 66 | raise TypeError("Wrong backend") |
| 63 | 67 |
invesalius/segmentation/brain/utils.py
| @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): | @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): | ||
| 5 | plaidml.settings._setup_for_test(plaidml.settings.user_settings) | 5 | plaidml.settings._setup_for_test(plaidml.settings.user_settings) |
| 6 | plaidml.settings.experimental = True | 6 | plaidml.settings.experimental = True |
| 7 | devices, _ = plaidml.devices(ctx, limit=100, return_all=True) | 7 | devices, _ = plaidml.devices(ctx, limit=100, return_all=True) |
| 8 | - if gpu: | ||
| 9 | - for device in devices: | ||
| 10 | - if b"cuda" in device.description.lower(): | ||
| 11 | - return device | ||
| 12 | - for device in devices: | ||
| 13 | - if b"opencl" in device.description.lower(): | ||
| 14 | - return device | 8 | + out_devices = [] |
| 15 | for device in devices: | 9 | for device in devices: |
| 16 | - if b"llvm" in device.description.lower(): | ||
| 17 | - return device | 10 | + points = 0 |
| 11 | + if b"cuda" in device.description.lower(): | ||
| 12 | + points += 1 | ||
| 13 | + if b"opencl" in device.description.lower(): | ||
| 14 | + points += 1 | ||
| 15 | + if b"nvidia" in device.description.lower(): | ||
| 16 | + points += 1 | ||
| 17 | + if b"amd" in device.description.lower(): | ||
| 18 | + points += 1 | ||
| 19 | + out_devices.append((points, device)) | ||
| 20 | + | ||
| 21 | + out_devices.sort(reverse=True) | ||
| 22 | + return {device.description.decode("utf8"): device.id.decode("utf8") for points, device in out_devices } |