diff --git a/invesalius/gui/brain_seg_dialog.py b/invesalius/gui/brain_seg_dialog.py index 3d6df91..80a8e55 100644 --- a/invesalius/gui/brain_seg_dialog.py +++ b/invesalius/gui/brain_seg_dialog.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: UTF-8 -*- +import importlib import os import pathlib import sys @@ -8,13 +9,8 @@ import sys import wx from wx.lib.pubsub import pub as Publisher -HAS_THEANO = True -HAS_PLAIDML = True - -try: - import theano -except ImportError: - HAS_THEANO = False +HAS_THEANO = bool(importlib.util.find_spec("theano")) +HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) # Linux if installed plaidml with pip3 install --user if sys.platform.startswith("linux"): @@ -29,13 +25,10 @@ elif sys.platform == "darwin": os.environ["RUNFILES_DIR"] = str(local_user_plaidml) os.environ["PLAIDML_NATIVE_PATH"] = str(pathlib.Path("/usr/local/lib/libplaidml.dylib").expanduser().absolute()) -try: - import plaidml -except ImportError: - HAS_PLAIDML = False import invesalius.data.slice_ as slc from invesalius.segmentation.brain import segment +from invesalius.segmentation.brain import utils @@ -49,11 +42,13 @@ class BrainSegmenterDialog(wx.Dialog): backends.append("Theano") self.segmenter = segment.BrainSegmenter() # self.pg_dialog = None - + self.plaidml_devices = utils.get_plaidml_devices() self.cb_backends = wx.ComboBox(self, wx.ID_ANY, choices=backends, value=backends[0], style=wx.CB_DROPDOWN | wx.CB_READONLY) w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) self.cb_backends.SetMinClientSize((w, -1)) self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) + self.lbl_device = wx.StaticText(self, -1, _("Device")) + self.cb_devices = wx.ComboBox(self, wx.ID_ANY, choices=list(self.plaidml_devices.keys()), value=list(self.plaidml_devices.keys())[0],style=wx.CB_DROPDOWN | wx.CB_READONLY) self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) w, h = self.CalcSizeFromTextSize("M" * 20) self.sld_threshold.SetMinClientSize((w, -1)) @@ -64,6 +59,7 @@ class BrainSegmenterDialog(wx.Dialog): self.btn_segment = wx.Button(self, wx.ID_ANY, _("Segment")) self.btn_stop = wx.Button(self, wx.ID_ANY, _("Stop")) self.btn_stop.Disable() + self.btn_close = wx.Button(self, wx.ID_CLOSE) self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) @@ -76,9 +72,13 @@ class BrainSegmenterDialog(wx.Dialog): sizer_backends = wx.BoxSizer(wx.HORIZONTAL) label_1 = wx.StaticText(self, wx.ID_ANY, _("Backend")) sizer_backends.Add(label_1, 0, wx.ALIGN_CENTER, 0) - sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 5) + sizer_backends.Add(self.cb_backends, 1, wx.LEFT, 0) main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) + sizer_devices = wx.BoxSizer(wx.HORIZONTAL) + sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) + sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) + main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) label_5 = wx.StaticText(self, wx.ID_ANY, _("Level of certainty")) main_sizer.Add(label_5, 0, wx.ALL, 5) sizer_3.Add(self.sld_threshold, 1, wx.ALIGN_CENTER | wx.BOTTOM | wx.EXPAND | wx.LEFT | wx.RIGHT, 5) @@ -86,20 +86,28 @@ class BrainSegmenterDialog(wx.Dialog): main_sizer.Add(sizer_3, 0, wx.EXPAND, 0) main_sizer.Add(self.progress, 0, wx.EXPAND | wx.ALL, 5) sizer_buttons = wx.BoxSizer(wx.HORIZONTAL) + sizer_buttons.Add(self.btn_close, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) sizer_buttons.Add(self.btn_stop, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) sizer_buttons.Add(self.btn_segment, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 5) main_sizer.Add(sizer_buttons, 0, wx.ALIGN_BOTTOM | wx.ALIGN_RIGHT | wx.ALL, 0) self.SetSizer(main_sizer) main_sizer.Fit(self) main_sizer.SetSizeHints(self) + + self.main_sizer = main_sizer + + self.OnSetBackend() + self.Layout() self.Centre() def __set_events(self): + self.cb_backends.Bind(wx.EVT_COMBOBOX, self.OnSetBackend) self.sld_threshold.Bind(wx.EVT_SCROLL, self.OnScrollThreshold) self.txt_threshold.Bind(wx.EVT_KILL_FOCUS, self.OnKillFocus) self.btn_segment.Bind(wx.EVT_BUTTON, self.OnSegment) self.btn_stop.Bind(wx.EVT_BUTTON, self.OnStop) + self.btn_close.Bind(wx.EVT_BUTTON, self.OnBtnClose) self.Bind(wx.EVT_CLOSE, self.OnClose) def CalcSizeFromTextSize(self, text): @@ -108,6 +116,19 @@ class BrainSegmenterDialog(wx.Dialog): width, height = dc.GetTextExtent(text) return width, height + def OnSetBackend(self, evt=None): + if self.cb_backends.GetValue().lower() == "plaidml": + self.lbl_device.Show() + self.cb_devices.Show() + self.chk_use_gpu.Hide() + else: + self.lbl_device.Hide() + self.cb_devices.Hide() + self.chk_use_gpu.Show() + + self.main_sizer.Fit(self) + self.main_sizer.SetSizeHints(self) + def OnScrollThreshold(self, evt): value = self.sld_threshold.GetValue() self.txt_threshold.SetValue("{:3d}%".format(self.sld_threshold.GetValue())) @@ -136,14 +157,20 @@ class BrainSegmenterDialog(wx.Dialog): def OnSegment(self, evt): image = slc.Slice().matrix backend = self.cb_backends.GetValue() + try: + device_id = self.plaidml_devices[self.cb_devices.GetValue()] + except KeyError: + device_id = "llvm_cpu.0" use_gpu = self.chk_use_gpu.GetValue() prob_threshold = self.sld_threshold.GetValue() / 100.0 self.btn_stop.Enable() self.btn_segment.Disable() + + print(device_id) # self.pg_dialog = wx.ProgressDialog(_("Brain segmenter"), _("Segmenting brain"), parent=self, style= wx.FRAME_FLOAT_ON_PARENT | wx.PD_CAN_ABORT | wx.PD_AUTO_HIDE | wx.PD_ELAPSED_TIME) # self.pg_dialog.Bind(wx.EVT_BUTTON, self.OnStop) # self.pg_dialog.Show() - self.segmenter.segment(image, prob_threshold, backend, use_gpu, self.SetProgress, self.AfterSegment) + self.segmenter.segment(image, prob_threshold, backend, device_id, use_gpu, self.SetProgress, self.AfterSegment) def OnStop(self, evt): self.segmenter.stop = True @@ -153,7 +180,12 @@ class BrainSegmenterDialog(wx.Dialog): self.btn_segment.Enable() evt.Skip() + def OnBtnClose(self, evt): + self.Close() + def AfterSegment(self): + self.btn_stop.Disable() + self.btn_segment.Disable() Publisher.sendMessage('Reload actual slice') def SetProgress(self, progress): diff --git a/invesalius/segmentation/brain/segment.py b/invesalius/segmentation/brain/segment.py index 400250d..f483b1d 100644 --- a/invesalius/segmentation/brain/segment.py +++ b/invesalius/segmentation/brain/segment.py @@ -50,14 +50,18 @@ class BrainSegmenter: self.stop = False self.segmented = False - def segment(self, image, prob_threshold, backend, use_gpu, progress_callback=None, after_segment=None): + def segment(self, image, prob_threshold, backend, device_id, use_gpu, progress_callback=None, after_segment=None): print("backend", backend) if backend.lower() == 'plaidml': os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" - device = utils.get_plaidml_devices(use_gpu) - os.environ["PLAIDML_DEVICE_IDS"] = device.id.decode("utf8") + os.environ["PLAIDML_DEVICE_IDS"] = device_id elif backend.lower() == 'theano': os.environ["KERAS_BACKEND"] = "theano" + if use_gpu: + os.environ["THEANO_FLAGS"] = "device=cuda0" + print("Use GPU theano", os.environ["THEANO_FLAGS"]) + else: + os.environ["THEANO_FLAGS"] = "device=cpu" else: raise TypeError("Wrong backend") diff --git a/invesalius/segmentation/brain/utils.py b/invesalius/segmentation/brain/utils.py index 13075da..a86a142 100644 --- a/invesalius/segmentation/brain/utils.py +++ b/invesalius/segmentation/brain/utils.py @@ -5,13 +5,18 @@ def get_plaidml_devices(gpu=False): plaidml.settings._setup_for_test(plaidml.settings.user_settings) plaidml.settings.experimental = True devices, _ = plaidml.devices(ctx, limit=100, return_all=True) - if gpu: - for device in devices: - if b"cuda" in device.description.lower(): - return device - for device in devices: - if b"opencl" in device.description.lower(): - return device + out_devices = [] for device in devices: - if b"llvm" in device.description.lower(): - return device + points = 0 + if b"cuda" in device.description.lower(): + points += 1 + if b"opencl" in device.description.lower(): + points += 1 + if b"nvidia" in device.description.lower(): + points += 1 + if b"amd" in device.description.lower(): + points += 1 + out_devices.append((points, device)) + + out_devices.sort(reverse=True) + return {device.description.decode("utf8"): device.id.decode("utf8") for points, device in out_devices } -- libgit2 0.21.2