Commit b331b9f29062a0e888dafc7a7d19ca8fa53286c1
1 parent
d5925187
Exists in
master
Improvements in brain segmenter gui
Showing
3 changed files
with
67 additions
and
26 deletions
Show diff stats
invesalius/gui/brain_seg_dialog.py
1 | 1 | #!/usr/bin/env python |
2 | 2 | # -*- coding: UTF-8 -*- |
3 | 3 | |
4 | +import importlib | |
4 | 5 | import os |
5 | 6 | import pathlib |
6 | 7 | import sys |
... | ... | @@ -8,13 +9,8 @@ import sys |
8 | 9 | import wx |
9 | 10 | from wx.lib.pubsub import pub as Publisher |
10 | 11 | |
11 | -HAS_THEANO = True | |
12 | -HAS_PLAIDML = True | |
13 | - | |
14 | -try: | |
15 | - import theano | |
16 | -except ImportError: | |
17 | - HAS_THEANO = False | |
12 | +HAS_THEANO = bool(importlib.util.find_spec("theano")) | |
13 | +HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) | |
18 | 14 | |
19 | 15 | # Linux if installed plaidml with pip3 install --user |
20 | 16 | if sys.platform.startswith("linux"): |
... | ... | @@ -29,13 +25,10 @@ elif sys.platform == "darwin": |
29 | 25 | os.environ["RUNFILES_DIR"] = str(local_user_plaidml) |
30 | 26 | os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) |
31 | 27 | |
32 | -try: | |
33 | - import plaidml | |
34 | -except ImportError: | |
35 | - HAS_PLAIDML = False | |
36 | 28 | |
37 | 29 | import invesalius.data.slice_ as slc |
38 | 30 | from invesalius.segmentation.brain import segment |
31 | +from invesalius.segmentation.brain import utils | |
39 | 32 | |
40 | 33 | |
41 | 34 | |
... | ... | @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): |
49 | 42 | backends.append("Theano") |
50 | 43 | self.segmenter = segment.BrainSegmenter() |
51 | 44 | # self.pg_dialog = None |
52 | - | |
45 | + self.plaidml_devices = utils.get_plaidml_devices() | |
53 | 46 | self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) |
54 | 47 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
55 | 48 | self.cb_backends.SetMinClientSize((w, -1)) |
56 | 49 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
50 | + self.lbl_device = wx.StaticText(self, -1, _("Device")) | |
51 | + self.cb_devices = wx.ComboBox(self, wx.ID_ANY, choices=list(self.plaidml_devices.keys()), value=list(self.plaidml_devices.keys())[0],style=wx.CB_DROPDOWN | wx.CB_READONLY) | |
57 | 52 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
58 | 53 | w, h = self.CalcSizeFromTextSize("M" * 20) |
59 | 54 | self.sld_threshold.SetMinClientSize((w, -1)) |
... | ... | @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): |
64 | 59 | self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) |
65 | 60 | self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) |
66 | 61 | self.btn_stop.Disable() |
62 | + self.btn_close = wx.Button(self, wx.ID_CLOSE) | |
67 | 63 | |
68 | 64 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) |
69 | 65 | |
... | ... | @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): |
76 | 72 | sizer_backends = wx.BoxSizer(wx.HORIZONTAL) |
77 | 73 | label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) |
78 | 74 | sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) |
79 | - sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 5) | |
75 | + sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 0) | |
80 | 76 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
81 | 77 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
78 | + sizer_devices = wx.BoxSizer(wx.HORIZONTAL) | |
79 | + sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) | |
80 | + sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) | |
81 | + main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) | |
82 | 82 | label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) |
83 | 83 | main_sizer.Add(label_5, 0, wx.ALL, 5) |
84 | 84 | sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) |
... | ... | @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): |
86 | 86 | main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) |
87 | 87 | main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) |
88 | 88 | sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) |
89 | + sizer_buttons.Add(self.btn_close, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) | |
89 | 90 | sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) |
90 | 91 | sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) |
91 | 92 | main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) |
92 | 93 | self.SetSizer(main_sizer) |
93 | 94 | main_sizer.Fit(self) |
94 | 95 | main_sizer.SetSizeHints(self) |
96 | + | |
97 | + self.main_sizer = main_sizer | |
98 | + | |
99 | + self.OnSetBackend() | |
100 | + | |
95 | 101 | self.Layout() |
96 | 102 | self.Centre() |
97 | 103 | |
98 | 104 | def __set_events(self): |
105 | + self.cb_backends.Bind(wx.EVT_COMBOBOX, self.OnSetBackend) | |
99 | 106 | self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) |
100 | 107 | self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) |
101 | 108 | self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) |
102 | 109 | self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) |
110 | + self.btn_close.Bind(wx.EVT_BUTTON, self.OnBtnClose) | |
103 | 111 | self.Bind(wx.EVT_CLOSE, self.OnClose) |
104 | 112 | |
105 | 113 | def CalcSizeFromTextSize(self, text): |
... | ... | @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): |
108 | 116 | width, height = dc.GetTextExtent(text) |
109 | 117 | return width, height |
110 | 118 | |
119 | + def OnSetBackend(self, evt=None): | |
120 | + if self.cb_backends.GetValue().lower() == "plaidml": | |
121 | + self.lbl_device.Show() | |
122 | + self.cb_devices.Show() | |
123 | + self.chk_use_gpu.Hide() | |
124 | + else: | |
125 | + self.lbl_device.Hide() | |
126 | + self.cb_devices.Hide() | |
127 | + self.chk_use_gpu.Show() | |
128 | + | |
129 | + self.main_sizer.Fit(self) | |
130 | + self.main_sizer.SetSizeHints(self) | |
131 | + | |
111 | 132 | def OnScrollThreshold(self, evt): |
112 | 133 | value = self.sld_threshold.GetValue() |
113 | 134 | self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) |
... | ... | @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): |
136 | 157 | def OnSegment(self, evt): |
137 | 158 | image = slc.Slice().matrix |
138 | 159 | backend = self.cb_backends.GetValue() |
160 | + try: | |
161 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
162 | + except KeyError: | |
163 | + device_id = "llvm_cpu.0" | |
139 | 164 | use_gpu = self.chk_use_gpu.GetValue() |
140 | 165 | prob_threshold = self.sld_threshold.GetValue() / 100.0 |
141 | 166 | self.btn_stop.Enable() |
142 | 167 | self.btn_segment.Disable() |
168 | + | |
169 | + print(device_id) | |
143 | 170 | # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) |
144 | 171 | # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) |
145 | 172 | # self.pg_dialog.Show() |
146 | - self.segmenter.segment(image, prob_threshold, backend, use_gpu, self.SetProgress, self.AfterSegment) | |
173 | + self.segmenter.segment(image, prob_threshold, backend, device_id, use_gpu, self.SetProgress, self.AfterSegment) | |
147 | 174 | |
148 | 175 | def OnStop(self, evt): |
149 | 176 | self.segmenter.stop = True |
... | ... | @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): |
153 | 180 | self.btn_segment.Enable() |
154 | 181 | evt.Skip() |
155 | 182 | |
183 | + def OnBtnClose(self, evt): | |
184 | + self.Close() | |
185 | + | |
156 | 186 | def AfterSegment(self): |
187 | + self.btn_stop.Disable() | |
188 | + self.btn_segment.Disable() | |
157 | 189 | Publisher.sendMessage('Reload actual slice') |
158 | 190 | |
159 | 191 | def SetProgress(self, progress): | ... | ... |
invesalius/segmentation/brain/segment.py
... | ... | @@ -50,14 +50,18 @@ class BrainSegmenter: |
50 | 50 | self.stop = False |
51 | 51 | self.segmented = False |
52 | 52 | |
53 | - def segment(self, image, prob_threshold, backend, use_gpu, progress_callback=None, after_segment=None): | |
53 | + def segment(self, image, prob_threshold, backend, device_id, use_gpu, progress_callback=None, after_segment=None): | |
54 | 54 | print("backend", backend) |
55 | 55 | if backend.lower() == 'plaidml': |
56 | 56 | os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" |
57 | - device = utils.get_plaidml_devices(use_gpu) | |
58 | - os.environ["PLAIDML_DEVICE_IDS"] = device.id.decode("utf8") | |
57 | + os.environ["PLAIDML_DEVICE_IDS"] = device_id | |
59 | 58 | elif backend.lower() == 'theano': |
60 | 59 | os.environ["KERAS_BACKEND"] = "theano" |
60 | + if use_gpu: | |
61 | + os.environ["THEANO_FLAGS"] = "device=cuda0" | |
62 | + print("Use GPU theano", os.environ["THEANO_FLAGS"]) | |
63 | + else: | |
64 | + os.environ["THEANO_FLAGS"] = "device=cpu" | |
61 | 65 | else: |
62 | 66 | raise TypeError("Wrong backend") |
63 | 67 | ... | ... |
invesalius/segmentation/brain/utils.py
... | ... | @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): |
5 | 5 | plaidml.settings._setup_for_test(plaidml.settings.user_settings) |
6 | 6 | plaidml.settings.experimental = True |
7 | 7 | devices, _ = plaidml.devices(ctx, limit=100, return_all=True) |
8 | - if gpu: | |
9 | - for device in devices: | |
10 | - if b"cuda" in device.description.lower(): | |
11 | - return device | |
12 | - for device in devices: | |
13 | - if b"opencl" in device.description.lower(): | |
14 | - return device | |
8 | + out_devices = [] | |
15 | 9 | for device in devices: |
16 | - if b"llvm" in device.description.lower(): | |
17 | - return device | |
10 | + points = 0 | |
11 | + if b"cuda" in device.description.lower(): | |
12 | + points += 1 | |
13 | + if b"opencl" in device.description.lower(): | |
14 | + points += 1 | |
15 | + if b"nvidia" in device.description.lower(): | |
16 | + points += 1 | |
17 | + if b"amd" in device.description.lower(): | |
18 | + points += 1 | |
19 | + out_devices.append((points, device)) | |
20 | + | |
21 | + out_devices.sort(reverse=True) | |
22 | + return {device.description.decode("utf8"): device.id.decode("utf8") for points, device in out_devices } | ... | ... |