Commit b331b9f29062a0e888dafc7a7d19ca8fa53286c1
1 parent
d5925187
Exists in
master
Improvements in brain segmenter gui
Showing
3 changed files
with
67 additions
and
26 deletions
Show diff stats
invesalius/gui/brain_seg_dialog.py
1 | #!/usr/bin/env python | 1 | #!/usr/bin/env python |
2 | # -*- coding: UTF-8 -*- | 2 | # -*- coding: UTF-8 -*- |
3 | 3 | ||
4 | +import importlib | ||
4 | import os | 5 | import os |
5 | import pathlib | 6 | import pathlib |
6 | import sys | 7 | import sys |
@@ -8,13 +9,8 @@ import sys | @@ -8,13 +9,8 @@ import sys | ||
8 | import wx | 9 | import wx |
9 | from wx.lib.pubsub import pub as Publisher | 10 | from wx.lib.pubsub import pub as Publisher |
10 | 11 | ||
11 | -HAS_THEANO = True | ||
12 | -HAS_PLAIDML = True | ||
13 | - | ||
14 | -try: | ||
15 | - import theano | ||
16 | -except ImportError: | ||
17 | - HAS_THEANO = False | 12 | +HAS_THEANO = bool(importlib.util.find_spec("theano")) |
13 | +HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) | ||
18 | 14 | ||
19 | # Linux if installed plaidml with pip3 install --user | 15 | # Linux if installed plaidml with pip3 install --user |
20 | if sys.platform.startswith("linux"): | 16 | if sys.platform.startswith("linux"): |
@@ -29,13 +25,10 @@ elif sys.platform == "darwin": | @@ -29,13 +25,10 @@ elif sys.platform == "darwin": | ||
29 | os.environ["RUNFILES_DIR"] = str(local_user_plaidml) | 25 | os.environ["RUNFILES_DIR"] = str(local_user_plaidml) |
30 | os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) | 26 | os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) |
31 | 27 | ||
32 | -try: | ||
33 | - import plaidml | ||
34 | -except ImportError: | ||
35 | - HAS_PLAIDML = False | ||
36 | 28 | ||
37 | import invesalius.data.slice_ as slc | 29 | import invesalius.data.slice_ as slc |
38 | from invesalius.segmentation.brain import segment | 30 | from invesalius.segmentation.brain import segment |
31 | +from invesalius.segmentation.brain import utils | ||
39 | 32 | ||
40 | 33 | ||
41 | 34 | ||
@@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): | ||
49 | backends.append("Theano") | 42 | backends.append("Theano") |
50 | self.segmenter = segment.BrainSegmenter() | 43 | self.segmenter = segment.BrainSegmenter() |
51 | # self.pg_dialog = None | 44 | # self.pg_dialog = None |
52 | - | 45 | + self.plaidml_devices = utils.get_plaidml_devices() |
53 | self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) | 46 | self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) |
54 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) | 47 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
55 | self.cb_backends.SetMinClientSize((w, -1)) | 48 | self.cb_backends.SetMinClientSize((w, -1)) |
56 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) | 49 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
50 | + self.lbl_device = wx.StaticText(self, -1, _("Device")) | ||
51 | + self.cb_devices = wx.ComboBox(self, wx.ID_ANY, choices=list(self.plaidml_devices.keys()), value=list(self.plaidml_devices.keys())[0],style=wx.CB_DROPDOWN | wx.CB_READONLY) | ||
57 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) | 52 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
58 | w, h = self.CalcSizeFromTextSize("M" * 20) | 53 | w, h = self.CalcSizeFromTextSize("M" * 20) |
59 | self.sld_threshold.SetMinClientSize((w, -1)) | 54 | self.sld_threshold.SetMinClientSize((w, -1)) |
@@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): | ||
64 | self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) | 59 | self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) |
65 | self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) | 60 | self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) |
66 | self.btn_stop.Disable() | 61 | self.btn_stop.Disable() |
62 | + self.btn_close = wx.Button(self, wx.ID_CLOSE) | ||
67 | 63 | ||
68 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) | 64 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) |
69 | 65 | ||
@@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): | ||
76 | sizer_backends = wx.BoxSizer(wx.HORIZONTAL) | 72 | sizer_backends = wx.BoxSizer(wx.HORIZONTAL) |
77 | label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) | 73 | label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) |
78 | sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) | 74 | sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) |
79 | - sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 5) | 75 | + sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 0) |
80 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) | 76 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
81 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) | 77 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
78 | + sizer_devices = wx.BoxSizer(wx.HORIZONTAL) | ||
79 | + sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) | ||
80 | + sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) | ||
81 | + main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) | ||
82 | label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) | 82 | label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) |
83 | main_sizer.Add(label_5, 0, wx.ALL, 5) | 83 | main_sizer.Add(label_5, 0, wx.ALL, 5) |
84 | sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) | 84 | sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) |
@@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): | ||
86 | main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) | 86 | main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) |
87 | main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) | 87 | main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) |
88 | sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) | 88 | sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) |
89 | + sizer_buttons.Add(self.btn_close, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | ||
89 | sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | 90 | sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) |
90 | sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | 91 | sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) |
91 | main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) | 92 | main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) |
92 | self.SetSizer(main_sizer) | 93 | self.SetSizer(main_sizer) |
93 | main_sizer.Fit(self) | 94 | main_sizer.Fit(self) |
94 | main_sizer.SetSizeHints(self) | 95 | main_sizer.SetSizeHints(self) |
96 | + | ||
97 | + self.main_sizer = main_sizer | ||
98 | + | ||
99 | + self.OnSetBackend() | ||
100 | + | ||
95 | self.Layout() | 101 | self.Layout() |
96 | self.Centre() | 102 | self.Centre() |
97 | 103 | ||
98 | def __set_events(self): | 104 | def __set_events(self): |
105 | + self.cb_backends.Bind(wx.EVT_COMBOBOX, self.OnSetBackend) | ||
99 | self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) | 106 | self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) |
100 | self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) | 107 | self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) |
101 | self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) | 108 | self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) |
102 | self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) | 109 | self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) |
110 | + self.btn_close.Bind(wx.EVT_BUTTON, self.OnBtnClose) | ||
103 | self.Bind(wx.EVT_CLOSE, self.OnClose) | 111 | self.Bind(wx.EVT_CLOSE, self.OnClose) |
104 | 112 | ||
105 | def CalcSizeFromTextSize(self, text): | 113 | def CalcSizeFromTextSize(self, text): |
@@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): | ||
108 | width, height = dc.GetTextExtent(text) | 116 | width, height = dc.GetTextExtent(text) |
109 | return width, height | 117 | return width, height |
110 | 118 | ||
119 | + def OnSetBackend(self, evt=None): | ||
120 | + if self.cb_backends.GetValue().lower() == "plaidml": | ||
121 | + self.lbl_device.Show() | ||
122 | + self.cb_devices.Show() | ||
123 | + self.chk_use_gpu.Hide() | ||
124 | + else: | ||
125 | + self.lbl_device.Hide() | ||
126 | + self.cb_devices.Hide() | ||
127 | + self.chk_use_gpu.Show() | ||
128 | + | ||
129 | + self.main_sizer.Fit(self) | ||
130 | + self.main_sizer.SetSizeHints(self) | ||
131 | + | ||
111 | def OnScrollThreshold(self, evt): | 132 | def OnScrollThreshold(self, evt): |
112 | value = self.sld_threshold.GetValue() | 133 | value = self.sld_threshold.GetValue() |
113 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) | 134 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) |
@@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): | ||
136 | def OnSegment(self, evt): | 157 | def OnSegment(self, evt): |
137 | image = slc.Slice().matrix | 158 | image = slc.Slice().matrix |
138 | backend = self.cb_backends.GetValue() | 159 | backend = self.cb_backends.GetValue() |
160 | + try: | ||
161 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | ||
162 | + except KeyError: | ||
163 | + device_id = "llvm_cpu.0" | ||
139 | use_gpu = self.chk_use_gpu.GetValue() | 164 | use_gpu = self.chk_use_gpu.GetValue() |
140 | prob_threshold = self.sld_threshold.GetValue() / 100.0 | 165 | prob_threshold = self.sld_threshold.GetValue() / 100.0 |
141 | self.btn_stop.Enable() | 166 | self.btn_stop.Enable() |
142 | self.btn_segment.Disable() | 167 | self.btn_segment.Disable() |
168 | + | ||
169 | + print(device_id) | ||
143 | # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) | 170 | # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) |
144 | # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) | 171 | # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) |
145 | # self.pg_dialog.Show() | 172 | # self.pg_dialog.Show() |
146 | - self.segmenter.segment(image, prob_threshold, backend, use_gpu, self.SetProgress, self.AfterSegment) | 173 | + self.segmenter.segment(image, prob_threshold, backend, device_id, use_gpu, self.SetProgress, self.AfterSegment) |
147 | 174 | ||
148 | def OnStop(self, evt): | 175 | def OnStop(self, evt): |
149 | self.segmenter.stop = True | 176 | self.segmenter.stop = True |
@@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): | ||
153 | self.btn_segment.Enable() | 180 | self.btn_segment.Enable() |
154 | evt.Skip() | 181 | evt.Skip() |
155 | 182 | ||
183 | + def OnBtnClose(self, evt): | ||
184 | + self.Close() | ||
185 | + | ||
156 | def AfterSegment(self): | 186 | def AfterSegment(self): |
187 | + self.btn_stop.Disable() | ||
188 | + self.btn_segment.Disable() | ||
157 | Publisher.sendMessage('Reload actual slice') | 189 | Publisher.sendMessage('Reload actual slice') |
158 | 190 | ||
159 | def SetProgress(self, progress): | 191 | def SetProgress(self, progress): |
invesalius/segmentation/brain/segment.py
@@ -50,14 +50,18 @@ class BrainSegmenter: | @@ -50,14 +50,18 @@ class BrainSegmenter: | ||
50 | self.stop = False | 50 | self.stop = False |
51 | self.segmented = False | 51 | self.segmented = False |
52 | 52 | ||
53 | - def segment(self, image, prob_threshold, backend, use_gpu, progress_callback=None, after_segment=None): | 53 | + def segment(self, image, prob_threshold, backend, device_id, use_gpu, progress_callback=None, after_segment=None): |
54 | print("backend", backend) | 54 | print("backend", backend) |
55 | if backend.lower() == 'plaidml': | 55 | if backend.lower() == 'plaidml': |
56 | os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" | 56 | os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" |
57 | - device = utils.get_plaidml_devices(use_gpu) | ||
58 | - os.environ["PLAIDML_DEVICE_IDS"] = device.id.decode("utf8") | 57 | + os.environ["PLAIDML_DEVICE_IDS"] = device_id |
59 | elif backend.lower() == 'theano': | 58 | elif backend.lower() == 'theano': |
60 | os.environ["KERAS_BACKEND"] = "theano" | 59 | os.environ["KERAS_BACKEND"] = "theano" |
60 | + if use_gpu: | ||
61 | + os.environ["THEANO_FLAGS"] = "device=cuda0" | ||
62 | + print("Use GPU theano", os.environ["THEANO_FLAGS"]) | ||
63 | + else: | ||
64 | + os.environ["THEANO_FLAGS"] = "device=cpu" | ||
61 | else: | 65 | else: |
62 | raise TypeError("Wrong backend") | 66 | raise TypeError("Wrong backend") |
63 | 67 |
invesalius/segmentation/brain/utils.py
@@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): | @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): | ||
5 | plaidml.settings._setup_for_test(plaidml.settings.user_settings) | 5 | plaidml.settings._setup_for_test(plaidml.settings.user_settings) |
6 | plaidml.settings.experimental = True | 6 | plaidml.settings.experimental = True |
7 | devices, _ = plaidml.devices(ctx, limit=100, return_all=True) | 7 | devices, _ = plaidml.devices(ctx, limit=100, return_all=True) |
8 | - if gpu: | ||
9 | - for device in devices: | ||
10 | - if b"cuda" in device.description.lower(): | ||
11 | - return device | ||
12 | - for device in devices: | ||
13 | - if b"opencl" in device.description.lower(): | ||
14 | - return device | 8 | + out_devices = [] |
15 | for device in devices: | 9 | for device in devices: |
16 | - if b"llvm" in device.description.lower(): | ||
17 | - return device | 10 | + points = 0 |
11 | + if b"cuda" in device.description.lower(): | ||
12 | + points += 1 | ||
13 | + if b"opencl" in device.description.lower(): | ||
14 | + points += 1 | ||
15 | + if b"nvidia" in device.description.lower(): | ||
16 | + points += 1 | ||
17 | + if b"amd" in device.description.lower(): | ||
18 | + points += 1 | ||
19 | + out_devices.append((points, device)) | ||
20 | + | ||
21 | + out_devices.sort(reverse=True) | ||
22 | + return {device.description.decode("utf8"): device.id.decode("utf8") for points, device in out_devices } |