Commit b331b9f29062a0e888dafc7a7d19ca8fa53286c1

Authored by Thiago Franco de Moraes
1 parent d5925187
Exists in master

Improvements in brain segmenter gui

invesalius/gui/brain_seg_dialog.py
1 #!/usr/bin/env python 1 #!/usr/bin/env python
2 # -*- coding: UTF-8 -*- 2 # -*- coding: UTF-8 -*-
3 3
  4 +import importlib
4 import os 5 import os
5 import pathlib 6 import pathlib
6 import sys 7 import sys
@@ -8,13 +9,8 @@ import sys @@ -8,13 +9,8 @@ import sys
8 import wx 9 import wx
9 from wx.lib.pubsub import pub as Publisher 10 from wx.lib.pubsub import pub as Publisher
10 11
11 -HAS_THEANO = True  
12 -HAS_PLAIDML = True  
13 -  
14 -try:  
15 - import theano  
16 -except ImportError:  
17 - HAS_THEANO = False 12 +HAS_THEANO = bool(importlib.util.find_spec("theano"))
  13 +HAS_PLAIDML = bool(importlib.util.find_spec("plaidml"))
18 14
19 # Linux if installed plaidml with pip3 install --user 15 # Linux if installed plaidml with pip3 install --user
20 if sys.platform.startswith("linux"): 16 if sys.platform.startswith("linux"):
@@ -29,13 +25,10 @@ elif sys.platform == "darwin": @@ -29,13 +25,10 @@ elif sys.platform == "darwin":
29 os.environ["RUNFILES_DIR"] = str(local_user_plaidml) 25 os.environ["RUNFILES_DIR"] = str(local_user_plaidml)
30 os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) 26 os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute())
31 27
32 -try:  
33 - import plaidml  
34 -except ImportError:  
35 - HAS_PLAIDML = False  
36 28
37 import invesalius.data.slice_ as slc 29 import invesalius.data.slice_ as slc
38 from invesalius.segmentation.brain import segment 30 from invesalius.segmentation.brain import segment
  31 +from invesalius.segmentation.brain import utils
39 32
40 33
41 34
@@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog):
49 backends.append("Theano") 42 backends.append("Theano")
50 self.segmenter = segment.BrainSegmenter() 43 self.segmenter = segment.BrainSegmenter()
51 # self.pg_dialog = None 44 # self.pg_dialog = None
52 - 45 + self.plaidml_devices = utils.get_plaidml_devices()
53 self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) 46 self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY)
54 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) 47 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends)))
55 self.cb_backends.SetMinClientSize((w, -1)) 48 self.cb_backends.SetMinClientSize((w, -1))
56 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) 49 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU"))
  50 + self.lbl_device = wx.StaticText(self, -1, _("Device"))
  51 + self.cb_devices = wx.ComboBox(self, wx.ID_ANY, choices=list(self.plaidml_devices.keys()), value=list(self.plaidml_devices.keys())[0],style=wx.CB_DROPDOWN | wx.CB_READONLY)
57 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) 52 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100)
58 w, h = self.CalcSizeFromTextSize("M" * 20) 53 w, h = self.CalcSizeFromTextSize("M" * 20)
59 self.sld_threshold.SetMinClientSize((w, -1)) 54 self.sld_threshold.SetMinClientSize((w, -1))
@@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog):
64 self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) 59 self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment"))
65 self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) 60 self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop"))
66 self.btn_stop.Disable() 61 self.btn_stop.Disable()
  62 + self.btn_close = wx.Button(self, wx.ID_CLOSE)
67 63
68 self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) 64 self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue()))
69 65
@@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog):
76 sizer_backends = wx.BoxSizer(wx.HORIZONTAL) 72 sizer_backends = wx.BoxSizer(wx.HORIZONTAL)
77 label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) 73 label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend"))
78 sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) 74 sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0)
79 - sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 5) 75 + sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 0)
80 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) 76 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5)
81 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) 77 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5)
  78 + sizer_devices = wx.BoxSizer(wx.HORIZONTAL)
  79 + sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0)
  80 + sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5)
  81 + main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5)
82 label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) 82 label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty"))
83 main_sizer.Add(label_5, 0, wx.ALL, 5) 83 main_sizer.Add(label_5, 0, wx.ALL, 5)
84 sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) 84 sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5)
@@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog):
86 main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) 86 main_sizer.Add(sizer_3, 0, wx.EXPAND, 0)
87 main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) 87 main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5)
88 sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) 88 sizer_buttons = wx.BoxSizer(wx.HORIZONTAL)
  89 + sizer_buttons.Add(self.btn_close, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5)
89 sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) 90 sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5)
90 sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) 91 sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5)
91 main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) 92 main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0)
92 self.SetSizer(main_sizer) 93 self.SetSizer(main_sizer)
93 main_sizer.Fit(self) 94 main_sizer.Fit(self)
94 main_sizer.SetSizeHints(self) 95 main_sizer.SetSizeHints(self)
  96 +
  97 + self.main_sizer = main_sizer
  98 +
  99 + self.OnSetBackend()
  100 +
95 self.Layout() 101 self.Layout()
96 self.Centre() 102 self.Centre()
97 103
98 def __set_events(self): 104 def __set_events(self):
  105 + self.cb_backends.Bind(wx.EVT_COMBOBOX, self.OnSetBackend)
99 self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) 106 self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold)
100 self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) 107 self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus)
101 self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) 108 self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment)
102 self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) 109 self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop)
  110 + self.btn_close.Bind(wx.EVT_BUTTON, self.OnBtnClose)
103 self.Bind(wx.EVT_CLOSE, self.OnClose) 111 self.Bind(wx.EVT_CLOSE, self.OnClose)
104 112
105 def CalcSizeFromTextSize(self, text): 113 def CalcSizeFromTextSize(self, text):
@@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog):
108 width, height = dc.GetTextExtent(text) 116 width, height = dc.GetTextExtent(text)
109 return width, height 117 return width, height
110 118
  119 + def OnSetBackend(self, evt=None):
  120 + if self.cb_backends.GetValue().lower() == "plaidml":
  121 + self.lbl_device.Show()
  122 + self.cb_devices.Show()
  123 + self.chk_use_gpu.Hide()
  124 + else:
  125 + self.lbl_device.Hide()
  126 + self.cb_devices.Hide()
  127 + self.chk_use_gpu.Show()
  128 +
  129 + self.main_sizer.Fit(self)
  130 + self.main_sizer.SetSizeHints(self)
  131 +
111 def OnScrollThreshold(self, evt): 132 def OnScrollThreshold(self, evt):
112 value = self.sld_threshold.GetValue() 133 value = self.sld_threshold.GetValue()
113 self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) 134 self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue()))
@@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog):
136 def OnSegment(self, evt): 157 def OnSegment(self, evt):
137 image = slc.Slice().matrix 158 image = slc.Slice().matrix
138 backend = self.cb_backends.GetValue() 159 backend = self.cb_backends.GetValue()
  160 + try:
  161 + device_id = self.plaidml_devices[self.cb_devices.GetValue()]
  162 + except KeyError:
  163 + device_id = "llvm_cpu.0"
139 use_gpu = self.chk_use_gpu.GetValue() 164 use_gpu = self.chk_use_gpu.GetValue()
140 prob_threshold = self.sld_threshold.GetValue() / 100.0 165 prob_threshold = self.sld_threshold.GetValue() / 100.0
141 self.btn_stop.Enable() 166 self.btn_stop.Enable()
142 self.btn_segment.Disable() 167 self.btn_segment.Disable()
  168 +
  169 + print(device_id)
143 # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) 170 # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME)
144 # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) 171 # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop)
145 # self.pg_dialog.Show() 172 # self.pg_dialog.Show()
146 - self.segmenter.segment(image, prob_threshold, backend, use_gpu, self.SetProgress, self.AfterSegment) 173 + self.segmenter.segment(image, prob_threshold, backend, device_id, use_gpu, self.SetProgress, self.AfterSegment)
147 174
148 def OnStop(self, evt): 175 def OnStop(self, evt):
149 self.segmenter.stop = True 176 self.segmenter.stop = True
@@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog):
153 self.btn_segment.Enable() 180 self.btn_segment.Enable()
154 evt.Skip() 181 evt.Skip()
155 182
  183 + def OnBtnClose(self, evt):
  184 + self.Close()
  185 +
156 def AfterSegment(self): 186 def AfterSegment(self):
  187 + self.btn_stop.Disable()
  188 + self.btn_segment.Disable()
157 Publisher.sendMessage('Reload actual slice') 189 Publisher.sendMessage('Reload actual slice')
158 190
159 def SetProgress(self, progress): 191 def SetProgress(self, progress):
invesalius/segmentation/brain/segment.py
@@ -50,14 +50,18 @@ class BrainSegmenter: @@ -50,14 +50,18 @@ class BrainSegmenter:
50 self.stop = False 50 self.stop = False
51 self.segmented = False 51 self.segmented = False
52 52
53 - def segment(self, image, prob_threshold, backend, use_gpu, progress_callback=None, after_segment=None): 53 + def segment(self, image, prob_threshold, backend, device_id, use_gpu, progress_callback=None, after_segment=None):
54 print("backend", backend) 54 print("backend", backend)
55 if backend.lower() == 'plaidml': 55 if backend.lower() == 'plaidml':
56 os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" 56 os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
57 - device = utils.get_plaidml_devices(use_gpu)  
58 - os.environ["PLAIDML_DEVICE_IDS"] = device.id.decode("utf8") 57 + os.environ["PLAIDML_DEVICE_IDS"] = device_id
59 elif backend.lower() == 'theano': 58 elif backend.lower() == 'theano':
60 os.environ["KERAS_BACKEND"] = "theano" 59 os.environ["KERAS_BACKEND"] = "theano"
  60 + if use_gpu:
  61 + os.environ["THEANO_FLAGS"] = "device=cuda0"
  62 + print("Use GPU theano", os.environ["THEANO_FLAGS"])
  63 + else:
  64 + os.environ["THEANO_FLAGS"] = "device=cpu"
61 else: 65 else:
62 raise TypeError("Wrong backend") 66 raise TypeError("Wrong backend")
63 67
invesalius/segmentation/brain/utils.py
@@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False):
5 plaidml.settings._setup_for_test(plaidml.settings.user_settings) 5 plaidml.settings._setup_for_test(plaidml.settings.user_settings)
6 plaidml.settings.experimental = True 6 plaidml.settings.experimental = True
7 devices, _ = plaidml.devices(ctx, limit=100, return_all=True) 7 devices, _ = plaidml.devices(ctx, limit=100, return_all=True)
8 - if gpu:  
9 - for device in devices:  
10 - if b"cuda" in device.description.lower():  
11 - return device  
12 - for device in devices:  
13 - if b"opencl" in device.description.lower():  
14 - return device 8 + out_devices = []
15 for device in devices: 9 for device in devices:
16 - if b"llvm" in device.description.lower():  
17 - return device 10 + points = 0
  11 + if b"cuda" in device.description.lower():
  12 + points += 1
  13 + if b"opencl" in device.description.lower():
  14 + points += 1
  15 + if b"nvidia" in device.description.lower():
  16 + points += 1
  17 + if b"amd" in device.description.lower():
  18 + points += 1
  19 + out_devices.append((points, device))
  20 +
  21 + out_devices.sort(reverse=True)
  22 + return {device.description.decode("utf8"): device.id.decode("utf8") for points, device in out_devices }