Commit b6ae340436ae6e2556de9f1521b7ffc029cac8a6

Authored by Renan
Committed by GitHub
2 parents 75cc013e 2a9da902
Exists in master

Merge branch 'master' into multimodal_tracking

invesalius/constants.py
... ... @@ -830,8 +830,14 @@ TREKKER_CONFIG = {'seed_max': 1, 'step_size': 0.1, 'min_fod': 0.1, 'probe_qualit
830 830  
831 831 MARKER_FILE_MAGICK_STRING = "INVESALIUS3_MARKER_FILE_"
832 832 CURRENT_MARKER_FILE_VERSION = 0
  833 +
833 834 WILDCARD_MARKER_FILES = _("Marker scanner coord files (*.mkss)|*.mkss")
834 835  
  836 +# Serial port
  837 +BAUD_RATES = [300, 1200, 2400, 4800, 9600, 19200, 38400, 57600, 115200]
  838 +BAUD_RATE_DEFAULT_SELECTION = 4
  839 +
  840 +#Robot
835 841 ROBOT_ElFIN_IP = ['Select robot IP:', '143.107.220.251', '169.254.153.251', '127.0.0.1']
836 842 ROBOT_ElFIN_PORT = 10003
837 843  
... ...
invesalius/data/serial_port_connection.py
... ... @@ -28,7 +28,7 @@ from invesalius.pubsub import pub as Publisher
28 28 class SerialPortConnection(threading.Thread):
29 29 BINARY_PULSE = b'\x01'
30 30  
31   - def __init__(self, port, serial_port_queue, event, sleep_nav):
  31 + def __init__(self, com_port, baud_rate, serial_port_queue, event, sleep_nav):
32 32 """
33 33 Thread created to communicate using the serial port to interact with software during neuronavigation.
34 34 """
... ... @@ -37,28 +37,29 @@ class SerialPortConnection(threading.Thread):
37 37 self.connection = None
38 38 self.stylusplh = False
39 39  
40   - self.port = port
  40 + self.com_port = com_port
  41 + self.baud_rate = baud_rate
41 42 self.serial_port_queue = serial_port_queue
42 43 self.event = event
43 44 self.sleep_nav = sleep_nav
44 45  
45 46 def Connect(self):
46   - if self.port is None:
  47 + if self.com_port is None:
47 48 print("Serial port init error: COM port is unset.")
48 49 return
49 50 try:
50 51 import serial
51   - self.connection = serial.Serial(self.port, baudrate=115200, timeout=0)
52   - print("Connection to port {} opened.".format(self.port))
  52 + self.connection = serial.Serial(self.com_port, baudrate=self.baud_rate, timeout=0)
  53 + print("Connection to port {} opened.".format(self.com_port))
53 54  
54 55 Publisher.sendMessage('Serial port connection', state=True)
55 56 except:
56   - print("Serial port init error: Connecting to port {} failed.".format(self.port))
  57 + print("Serial port init error: Connecting to port {} failed.".format(self.com_port))
57 58  
58 59 def Disconnect(self):
59 60 if self.connection:
60 61 self.connection.close()
61   - print("Connection to port {} closed.".format(self.port))
  62 + print("Connection to port {} closed.".format(self.com_port))
62 63  
63 64 Publisher.sendMessage('Serial port connection', state=False)
64 65  
... ... @@ -74,12 +75,11 @@ class SerialPortConnection(threading.Thread):
74 75 trigger_on = False
75 76 try:
76 77 lines = self.connection.readlines()
  78 + if lines:
  79 + trigger_on = True
77 80 except:
78 81 print("Error: Serial port could not be read.")
79 82  
80   - if lines:
81   - trigger_on = True
82   -
83 83 if self.stylusplh:
84 84 trigger_on = True
85 85 self.stylusplh = False
... ...
invesalius/data/trackers.py
... ... @@ -299,7 +299,7 @@ def PlhSerialConnection(tracker_id):
299 299 import serial
300 300 from wx import ID_OK
301 301 trck_init = None
302   - dlg_port = dlg.SetCOMport()
  302 + dlg_port = dlg.SetCOMPort(select_baud_rate=False)
303 303 if dlg_port.ShowModal() == ID_OK:
304 304 com_port = dlg_port.GetValue()
305 305 try:
... ...
invesalius/gui/brain_seg_dialog.py
... ... @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano"))
22 22 HAS_PLAIDML = bool(importlib.util.find_spec("plaidml"))
23 23 PLAIDML_DEVICES = {}
24 24  
  25 +try:
  26 + import torch
  27 + HAS_TORCH = True
  28 +except ImportError:
  29 + HAS_TORCH = False
  30 +
  31 +if HAS_TORCH:
  32 + TORCH_DEVICES = {}
  33 + if torch.cuda.is_available():
  34 + for i in range(torch.cuda.device_count()):
  35 + name = torch.cuda.get_device_name()
  36 + device_id = f'cuda:{i}'
  37 + TORCH_DEVICES[name] = device_id
  38 + TORCH_DEVICES['CPU'] = 'cpu'
  39 +
  40 +
25 41  
26 42 if HAS_PLAIDML:
27 43 with multiprocessing.Pool(1) as p:
... ... @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog):
43 59 style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT,
44 60 )
45 61 backends = []
  62 + if HAS_TORCH:
  63 + backends.append("Pytorch")
46 64 if HAS_PLAIDML:
47 65 backends.append("PlaidML")
48 66 if HAS_THEANO:
49 67 backends.append("Theano")
50 68 # self.segmenter = segment.BrainSegmenter()
51 69 # self.pg_dialog = None
  70 + self.torch_devices = TORCH_DEVICES
52 71 self.plaidml_devices = PLAIDML_DEVICES
53 72  
54 73 self.ps = None
... ... @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog):
65 84 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends)))
66 85 self.cb_backends.SetMinClientSize((w, -1))
67 86 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU"))
68   - if HAS_PLAIDML:
  87 + if HAS_TORCH or HAS_PLAIDML:
  88 + if HAS_TORCH:
  89 + choices = list(self.torch_devices.keys())
  90 + value = choices[0]
  91 + else:
  92 + choices = list(self.plaidml_devices.keys())
  93 + value = choices[0]
69 94 self.lbl_device = wx.StaticText(self, -1, _("Device"))
70 95 self.cb_devices = wx.ComboBox(
71 96 self,
72 97 wx.ID_ANY,
73   - choices=list(self.plaidml_devices.keys()),
74   - value=list(self.plaidml_devices.keys())[0],
  98 + choices=choices,
  99 + value=value,
75 100 style=wx.CB_DROPDOWN | wx.CB_READONLY,
76 101 )
77 102 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100)
... ... @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog):
109 134 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5)
110 135 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5)
111 136 sizer_devices = wx.BoxSizer(wx.HORIZONTAL)
112   - if HAS_PLAIDML:
  137 + if HAS_TORCH or HAS_PLAIDML:
113 138 sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0)
114 139 sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5)
115 140 main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5)
... ... @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog):
177 202 return width, height
178 203  
179 204 def OnSetBackend(self, evt=None):
180   - if self.cb_backends.GetValue().lower() == "plaidml":
  205 + if self.cb_backends.GetValue().lower() == "pytorch":
  206 + if HAS_TORCH:
  207 + choices = list(self.torch_devices.keys())
  208 + self.cb_devices.Clear()
  209 + self.cb_devices.SetItems(choices)
  210 + self.cb_devices.SetValue(choices[0])
  211 + self.lbl_device.Show()
  212 + self.cb_devices.Show()
  213 + self.chk_use_gpu.Hide()
  214 + elif self.cb_backends.GetValue().lower() == "plaidml":
181 215 if HAS_PLAIDML:
  216 + choices = list(self.plaidml_devices.keys())
  217 + self.cb_devices.Clear()
  218 + self.cb_devices.SetItems(choices)
  219 + self.cb_devices.SetValue(choices[0])
182 220 self.lbl_device.Show()
183 221 self.cb_devices.Show()
184 222 self.chk_use_gpu.Hide()
... ... @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog):
216 254 self.elapsed_time_timer.Start(1000)
217 255 image = slc.Slice().matrix
218 256 backend = self.cb_backends.GetValue()
219   - try:
220   - device_id = self.plaidml_devices[self.cb_devices.GetValue()]
221   - except (KeyError, AttributeError):
222   - device_id = "llvm_cpu.0"
  257 + if backend.lower() == "pytorch":
  258 + try:
  259 + device_id = self.torch_devices[self.cb_devices.GetValue()]
  260 + except (KeyError, AttributeError):
  261 + device_id = "cpu"
  262 + else:
  263 + try:
  264 + device_id = self.plaidml_devices[self.cb_devices.GetValue()]
  265 + except (KeyError, AttributeError):
  266 + device_id = "llvm_cpu.0"
223 267 apply_wwwl = self.chk_apply_wwwl.GetValue()
224 268 create_new_mask = self.chk_new_mask.GetValue()
225 269 use_gpu = self.chk_use_gpu.GetValue()
... ...
invesalius/gui/dialogs.py
... ... @@ -4632,7 +4632,8 @@ class SetNDIconfigs(wx.Dialog):
4632 4632 self._init_gui()
4633 4633  
4634 4634 def serial_ports(self):
4635   - """ Lists serial port names and pre-select the description containing NDI
  4635 + """
  4636 + Lists serial port names and pre-select the description containing NDI
4636 4637 """
4637 4638 import serial.tools.list_ports
4638 4639  
... ... @@ -4748,13 +4749,16 @@ class SetNDIconfigs(wx.Dialog):
4748 4749 return self.com_ports.GetString(self.com_ports.GetSelection()).encode(const.FS_ENCODE), fn_probe, fn_ref, fn_obj
4749 4750  
4750 4751  
4751   -class SetCOMport(wx.Dialog):
4752   - def __init__(self, title=_("Select COM port")):
4753   - wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE|wx.FRAME_FLOAT_ON_PARENT|wx.STAY_ON_TOP)
  4752 +class SetCOMPort(wx.Dialog):
  4753 + def __init__(self, select_baud_rate, title=_("Select COM port")):
  4754 + wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT | wx.STAY_ON_TOP)
  4755 +
  4756 + self.select_baud_rate = select_baud_rate
4754 4757 self._init_gui()
4755 4758  
4756 4759 def serial_ports(self):
4757   - """ Lists serial port names
  4760 + """
  4761 + Lists serial port names
4758 4762 """
4759 4763 import serial.tools.list_ports
4760 4764 if sys.platform.startswith('win'):
... ... @@ -4764,12 +4768,26 @@ class SetCOMport(wx.Dialog):
4764 4768 return ports
4765 4769  
4766 4770 def _init_gui(self):
4767   - self.com_ports = wx.ComboBox(self, -1, style=wx.CB_DROPDOWN|wx.CB_READONLY)
  4771 + # COM port selection
4768 4772 ports = self.serial_ports()
4769   - self.com_ports.Append(ports)
  4773 + self.com_port_dropdown = wx.ComboBox(self, -1, choices=ports, style=wx.CB_DROPDOWN | wx.CB_READONLY)
  4774 + self.com_port_dropdown.SetSelection(0)
  4775 +
  4776 + com_port_text_and_dropdown = wx.BoxSizer(wx.VERTICAL)
  4777 + com_port_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "COM port"), 0, wx.TOP | wx.RIGHT,5)
  4778 + com_port_text_and_dropdown.Add(self.com_port_dropdown, 0, wx.EXPAND)
  4779 +
  4780 + # Baud rate selection
  4781 + if self.select_baud_rate:
  4782 + baud_rates_as_strings = [str(baud_rate) for baud_rate in const.BAUD_RATES]
  4783 + self.baud_rate_dropdown = wx.ComboBox(self, -1, choices=baud_rates_as_strings, style=wx.CB_DROPDOWN | wx.CB_READONLY)
  4784 + self.baud_rate_dropdown.SetSelection(const.BAUD_RATE_DEFAULT_SELECTION)
4770 4785  
4771   - # self.goto_orientation.SetSelection(cb_init)
  4786 + baud_rate_text_and_dropdown = wx.BoxSizer(wx.VERTICAL)
  4787 + baud_rate_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "Baud rate"), 0, wx.TOP | wx.RIGHT,5)
  4788 + baud_rate_text_and_dropdown.Add(self.baud_rate_dropdown, 0, wx.EXPAND)
4772 4789  
  4790 + # OK and Cancel buttons
4773 4791 btn_ok = wx.Button(self, wx.ID_OK)
4774 4792 btn_ok.SetHelpText("")
4775 4793 btn_ok.SetDefault()
... ... @@ -4782,10 +4800,16 @@ class SetCOMport(wx.Dialog):
4782 4800 btnsizer.AddButton(btn_cancel)
4783 4801 btnsizer.Realize()
4784 4802  
  4803 + # Set up the main sizer
4785 4804 main_sizer = wx.BoxSizer(wx.VERTICAL)
4786 4805  
4787 4806 main_sizer.Add((5, 5))
4788   - main_sizer.Add(self.com_ports, 1, wx.EXPAND|wx.LEFT|wx.RIGHT, 5)
  4807 + main_sizer.Add(com_port_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5)
  4808 +
  4809 + if self.select_baud_rate:
  4810 + main_sizer.Add((5, 5))
  4811 + main_sizer.Add(baud_rate_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5)
  4812 +
4789 4813 main_sizer.Add((5, 5))
4790 4814 main_sizer.Add(btnsizer, 0, wx.EXPAND)
4791 4815 main_sizer.Add((5, 5))
... ... @@ -4796,7 +4820,14 @@ class SetCOMport(wx.Dialog):
4796 4820 self.CenterOnParent()
4797 4821  
4798 4822 def GetValue(self):
4799   - return self.com_ports.GetString(self.com_ports.GetSelection())
  4823 + com_port = self.com_port_dropdown.GetString(self.com_port_dropdown.GetSelection())
  4824 +
  4825 + if self.select_baud_rate:
  4826 + baud_rate = self.baud_rate_dropdown.GetString(self.baud_rate_dropdown.GetSelection())
  4827 + else:
  4828 + baud_rate = None
  4829 +
  4830 + return com_port, baud_rate
4800 4831  
4801 4832  
4802 4833 class ManualWWWLDialog(wx.Dialog):
... ...
invesalius/gui/task_navigator.py
... ... @@ -234,8 +234,8 @@ class InnerFoldPanel(wx.Panel):
234 234 checkcamera.Bind(wx.EVT_CHECKBOX, self.OnVolumeCamera)
235 235 self.checkcamera = checkcamera
236 236  
237   - # Check box to create markers from serial port
238   - tooltip = wx.ToolTip(_("Enable serial port communication for creating markers"))
  237 + # Check box to use serial port to trigger pulse signal and create markers
  238 + tooltip = wx.ToolTip(_("Enable serial port communication to trigger pulse and create markers"))
239 239 checkbox_serial_port = wx.CheckBox(self, -1, _('Serial port'))
240 240 checkbox_serial_port.SetToolTip(tooltip)
241 241 checkbox_serial_port.SetValue(False)
... ... @@ -297,14 +297,20 @@ class InnerFoldPanel(wx.Panel):
297 297 self.checkobj.Enable(True)
298 298  
299 299 def OnEnableSerialPort(self, evt, ctrl):
300   - com_port = None
301 300 if ctrl.GetValue():
302 301 from wx import ID_OK
303   - dlg_port = dlg.SetCOMport()
304   - if dlg_port.ShowModal() == ID_OK:
305   - com_port = dlg_port.GetValue()
  302 + dlg_port = dlg.SetCOMPort(select_baud_rate=False)
306 303  
307   - Publisher.sendMessage('Update serial port', serial_port=com_port)
  304 + if dlg_port.ShowModal() != ID_OK:
  305 + ctrl.SetValue(False)
  306 + return
  307 +
  308 + com_port = dlg_port.GetValue()
  309 + baud_rate = 115200
  310 +
  311 + Publisher.sendMessage('Update serial port', serial_port_in_use=True, com_port=com_port, baud_rate=baud_rate)
  312 + else:
  313 + Publisher.sendMessage('Update serial port', serial_port_in_use=False)
308 314  
309 315 def OnShowObject(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True):
310 316 if not evt:
... ... @@ -503,7 +509,6 @@ class NeuronavigationPanel(wx.Panel):
503 509 Publisher.subscribe(self.LoadImageFiducials, 'Load image fiducials')
504 510 Publisher.subscribe(self.SetImageFiducial, 'Set image fiducial')
505 511 Publisher.subscribe(self.SetTrackerFiducial, 'Set tracker fiducial')
506   - Publisher.subscribe(self.UpdateSerialPort, 'Update serial port')
507 512 Publisher.subscribe(self.UpdateTrackObjectState, 'Update track object state')
508 513 Publisher.subscribe(self.UpdateImageCoordinates, 'Set cross focal point')
509 514 Publisher.subscribe(self.OnDisconnectTracker, 'Disconnect tracker')
... ... @@ -627,9 +632,6 @@ class NeuronavigationPanel(wx.Panel):
627 632 def UpdateTrackObjectState(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True):
628 633 self.navigation.track_obj = flag
629 634  
630   - def UpdateSerialPort(self, serial_port):
631   - self.navigation.serial_port = serial_port
632   -
633 635 def ResetICP(self):
634 636 self.icp.ResetICP()
635 637 self.checkbox_icp.Enable(False)
... ...
invesalius/inv_paths.py
... ... @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co
27 27 USER_INV_DIR = CONF_DIR.joinpath("invesalius")
28 28 USER_PRESET_DIR = USER_INV_DIR.joinpath("presets")
29 29 USER_LOG_DIR = USER_INV_DIR.joinpath("logs")
  30 +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/")
30 31 USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting")
31 32 TEMP_DIR = tempfile.gettempdir()
32 33  
... ... @@ -97,6 +98,7 @@ def create_conf_folders():
97 98 USER_INV_DIR.mkdir(parents=True, exist_ok=True)
98 99 USER_PRESET_DIR.mkdir(parents=True, exist_ok=True)
99 100 USER_LOG_DIR.mkdir(parents=True, exist_ok=True)
  101 + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True)
100 102 USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True)
101 103  
102 104  
... ...
invesalius/navigation/navigation.py
... ... @@ -171,7 +171,9 @@ class Navigation():
171 171 self.sleep_nav = const.SLEEP_NAVIGATION
172 172  
173 173 # Serial port
174   - self.serial_port = None
  174 + self.serial_port_in_use = False
  175 + self.com_port = None
  176 + self.baud_rate = None
175 177 self.serial_port_connection = None
176 178  
177 179 # During navigation
... ... @@ -181,6 +183,7 @@ class Navigation():
181 183  
182 184 def __bind_events(self):
183 185 Publisher.subscribe(self.CoilAtTarget, 'Coil at target')
  186 + Publisher.subscribe(self.UpdateSerialPort, 'Update serial port')
184 187  
185 188 def CoilAtTarget(self, state):
186 189 self.coil_at_target = state
... ... @@ -189,8 +192,10 @@ class Navigation():
189 192 self.sleep_nav = sleep
190 193 self.serial_port_connection.sleep_nav = sleep
191 194  
192   - def SerialPortEnabled(self):
193   - return self.serial_port is not None
  195 + def UpdateSerialPort(self, serial_port_in_use, com_port=None, baud_rate=None):
  196 + self.serial_port_in_use = serial_port_in_use
  197 + self.com_port = com_port
  198 + self.baud_rate = baud_rate
194 199  
195 200 def SetReferenceMode(self, value):
196 201 self.ref_mode_id = value
... ... @@ -218,7 +223,7 @@ class Navigation():
218 223 return fre, fre <= const.FIDUCIAL_REGISTRATION_ERROR_THRESHOLD
219 224  
220 225 def PedalStateChanged(self, state):
221   - if state is True and self.coil_at_target and self.SerialPortEnabled():
  226 + if state is True and self.coil_at_target and self.serial_port_in_use:
222 227 self.serial_port_connection.SendPulse()
223 228  
224 229 def StartNavigation(self, tracker):
... ... @@ -230,7 +235,7 @@ class Navigation():
230 235 if self.event.is_set():
231 236 self.event.clear()
232 237  
233   - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded]
  238 + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded]
234 239 vis_queues = [self.coord_queue, self.serial_port_queue, self.tracts_queue, self.icp_queue, self.robottarget_queue]
235 240  
236 241 Publisher.sendMessage("Navigation status", nav_status=True, vis_status=vis_components)
... ... @@ -279,12 +284,13 @@ class Navigation():
279 284  
280 285 if not errors:
281 286 #TODO: Test the serial port thread
282   - if self.SerialPortEnabled():
  287 + if self.serial_port_in_use:
283 288 self.serial_port_connection = spc.SerialPortConnection(
284   - self.serial_port,
285   - self.serial_port_queue,
286   - self.event,
287   - self.sleep_nav,
  289 + com_port=self.com_port,
  290 + baud_rate=self.baud_rate,
  291 + serial_port_queue=self.serial_port_queue,
  292 + event=self.event,
  293 + sleep_nav=self.sleep_nav,
288 294 )
289 295 self.serial_port_connection.Connect()
290 296 jobs_list.append(self.serial_port_connection)
... ... @@ -330,7 +336,7 @@ class Navigation():
330 336 if self.serial_port_connection is not None:
331 337 self.serial_port_connection.join()
332 338  
333   - if self.SerialPortEnabled():
  339 + if self.serial_port_in_use:
334 340 self.serial_port_queue.clear()
335 341 self.serial_port_queue.join()
336 342  
... ... @@ -341,5 +347,5 @@ class Navigation():
341 347 self.tracts_queue.clear()
342 348 self.tracts_queue.join()
343 349  
344   - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded]
  350 + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded]
345 351 Publisher.sendMessage("Navigation status", nav_status=False, vis_status=vis_components)
... ...
invesalius/net/utils.py 0 → 100644
... ... @@ -0,0 +1,48 @@
  1 +from urllib.error import HTTPError
  2 +from urllib.request import urlopen, Request
  3 +from urllib.parse import urlparse
  4 +import pathlib
  5 +import tempfile
  6 +import typing
  7 +import hashlib
  8 +import os
  9 +import shutil
  10 +
  11 +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None):
  12 + file_size = None
  13 + total_downloaded = 0
  14 + if hash is not None:
  15 + calc_hash = hashlib.sha256()
  16 + req = Request(url)
  17 + response = urlopen(req)
  18 + meta = response.info()
  19 + if hasattr(meta, "getheaders"):
  20 + content_length = meta.getheaders("Content-Length")
  21 + else:
  22 + content_length = meta.get_all("Content-Length")
  23 +
  24 + if content_length is not None and len(content_length) > 0:
  25 + file_size = int(content_length[0])
  26 + dst.parent.mkdir(parents=True, exist_ok=True)
  27 + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent)
  28 + try:
  29 + while True:
  30 + buffer = response.read(8192)
  31 + if len(buffer) == 0:
  32 + break
  33 + total_downloaded += len(buffer)
  34 + f.write(buffer)
  35 + if hash:
  36 + calc_hash.update(buffer)
  37 + if callback is not None:
  38 + callback(100 * total_downloaded/file_size)
  39 + f.close()
  40 + if hash is not None:
  41 + digest = calc_hash.hexdigest()
  42 + if digest != hash:
  43 + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")')
  44 + shutil.move(f.name, dst)
  45 + finally:
  46 + f.close()
  47 + if os.path.exists(f.name):
  48 + os.remove(f.name)
... ...
invesalius/segmentation/brain/model.py 0 → 100644
... ... @@ -0,0 +1,149 @@
  1 +from collections import OrderedDict
  2 +
  3 +import torch
  4 +import torch.nn as nn
  5 +
  6 +SIZE = 48
  7 +
  8 +class Unet3D(nn.Module):
  9 + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py
  10 + def __init__(self, in_channels=1, out_channels=1, init_features=8):
  11 + super().__init__()
  12 + features = init_features
  13 +
  14 + self.encoder1 = self._block(
  15 + in_channels, features=features, padding=2, name="enc1"
  16 + )
  17 + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
  18 +
  19 + self.encoder2 = self._block(
  20 + features, features=features * 2, padding=2, name="enc2"
  21 + )
  22 + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
  23 +
  24 + self.encoder3 = self._block(
  25 + features * 2, features=features * 4, padding=2, name="enc3"
  26 + )
  27 + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
  28 +
  29 + self.encoder4 = self._block(
  30 + features * 4, features=features * 8, padding=2, name="enc4"
  31 + )
  32 + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
  33 +
  34 + self.bottleneck = self._block(
  35 + features * 8, features=features * 16, padding=2, name="bottleneck"
  36 + )
  37 +
  38 + self.upconv4 = nn.ConvTranspose3d(
  39 + features * 16, features * 8, kernel_size=4, stride=2, padding=1
  40 + )
  41 + self.decoder4 = self._block(
  42 + features * 16, features=features * 8, padding=2, name="dec4"
  43 + )
  44 +
  45 + self.upconv3 = nn.ConvTranspose3d(
  46 + features * 8, features * 4, kernel_size=4, stride=2, padding=1
  47 + )
  48 + self.decoder3 = self._block(
  49 + features * 8, features=features * 4, padding=2, name="dec4"
  50 + )
  51 +
  52 + self.upconv2 = nn.ConvTranspose3d(
  53 + features * 4, features * 2, kernel_size=4, stride=2, padding=1
  54 + )
  55 + self.decoder2 = self._block(
  56 + features * 4, features=features * 2, padding=2, name="dec4"
  57 + )
  58 +
  59 + self.upconv1 = nn.ConvTranspose3d(
  60 + features * 2, features, kernel_size=4, stride=2, padding=1
  61 + )
  62 + self.decoder1 = self._block(
  63 + features * 2, features=features, padding=2, name="dec4"
  64 + )
  65 +
  66 + self.conv = nn.Conv3d(
  67 + in_channels=features, out_channels=out_channels, kernel_size=1
  68 + )
  69 +
  70 + def forward(self, img):
  71 + enc1 = self.encoder1(img)
  72 + enc2 = self.encoder2(self.pool1(enc1))
  73 + enc3 = self.encoder3(self.pool2(enc2))
  74 + enc4 = self.encoder4(self.pool3(enc3))
  75 +
  76 + bottleneck = self.bottleneck(self.pool4(enc4))
  77 +
  78 + upconv4 = self.upconv4(bottleneck)
  79 + dec4 = torch.cat((upconv4, enc4), dim=1)
  80 + dec4 = self.decoder4(dec4)
  81 +
  82 + upconv3 = self.upconv3(dec4)
  83 + dec3 = torch.cat((upconv3, enc3), dim=1)
  84 + dec3 = self.decoder3(dec3)
  85 +
  86 + upconv2 = self.upconv2(dec3)
  87 + dec2 = torch.cat((upconv2, enc2), dim=1)
  88 + dec2 = self.decoder2(dec2)
  89 +
  90 + upconv1 = self.upconv1(dec2)
  91 + dec1 = torch.cat((upconv1, enc1), dim=1)
  92 + dec1 = self.decoder1(dec1)
  93 +
  94 + conv = self.conv(dec1)
  95 +
  96 + sigmoid = torch.sigmoid(conv)
  97 +
  98 + return sigmoid
  99 +
  100 + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"):
  101 + return nn.Sequential(
  102 + OrderedDict(
  103 + (
  104 + (
  105 + f"{name}_conv1",
  106 + nn.Conv3d(
  107 + in_channels=in_channels,
  108 + out_channels=features,
  109 + kernel_size=kernel_size,
  110 + padding=padding,
  111 + bias=True,
  112 + ),
  113 + ),
  114 + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)),
  115 + (f"{name}_relu1", nn.ReLU(inplace=True)),
  116 + (
  117 + f"{name}_conv2",
  118 + nn.Conv3d(
  119 + in_channels=features,
  120 + out_channels=features,
  121 + kernel_size=kernel_size,
  122 + padding=padding,
  123 + bias=True,
  124 + ),
  125 + ),
  126 + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)),
  127 + (f"{name}_relu2", nn.ReLU(inplace=True)),
  128 + )
  129 + )
  130 + )
  131 +
  132 +
  133 +def main():
  134 + import torchviz
  135 + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  136 + model = Unet3D()
  137 + model.to(dev)
  138 + model.eval()
  139 + print(next(model.parameters()).is_cuda) # True
  140 + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev)
  141 + out = model(img)
  142 + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
  143 + dot.render("unet", format="png")
  144 + torch.save(model, "model.pth")
  145 + print(dot)
  146 +
  147 +
  148 +if __name__ == "__main__":
  149 + main()
... ...
invesalius/segmentation/brain/segment.py
... ... @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc
13 13 from invesalius import inv_paths
14 14 from invesalius.data import imagedata_utils
15 15 from invesalius.utils import new_name_by_pattern
  16 +from invesalius.net.utils import download_url_to_file
  17 +from invesalius import inv_paths
16 18  
17 19 from . import utils
18 20  
... ... @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE):
64 66 0 : ez - iz, 0 : ey - iy, 0 : ex - ix
65 67 ]
66 68  
  69 +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE):
  70 + import torch
  71 + with torch.no_grad():
  72 + (iz, ez), (iy, ey), (ix, ex) = patch
  73 + sub_mask = nn_model(
  74 + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device)
  75 + ).cpu().numpy()
  76 + return sub_mask.reshape(patch_size, patch_size, patch_size)[
  77 + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix
  78 + ]
  79 +
67 80  
68 81 def brain_segment(image, probability_array, comm_array):
69 82 import keras
... ... @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array):
89 102 comm_array[0] = np.Inf
90 103  
91 104  
  105 +def download_callback(comm_array):
  106 + def _download_callback(value):
  107 + comm_array[0] = value
  108 + return _download_callback
  109 +
  110 +def brain_segment_torch(image, device_id, probability_array, comm_array):
  111 + import torch
  112 + from .model import Unet3D
  113 + device = torch.device(device_id)
  114 + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt")
  115 + if not state_dict_file.exists():
  116 + download_url_to_file(
  117 + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt",
  118 + state_dict_file,
  119 + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22",
  120 + download_callback(comm_array)
  121 + )
  122 + state_dict = torch.load(str(state_dict_file))
  123 + model = Unet3D()
  124 + model.load_state_dict(state_dict["model_state_dict"])
  125 + model.to(device)
  126 + model.eval()
  127 +
  128 + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32)
  129 + sums = np.zeros_like(image)
  130 + # segmenting by patches
  131 + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP):
  132 + comm_array[0] = completion
  133 + (iz, ez), (iy, ey), (ix, ex) = patch
  134 + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE)
  135 + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask
  136 + sums[iz:ez, iy:ey, ix:ex] += 1
  137 +
  138 + probability_array /= sums
  139 + comm_array[0] = np.Inf
  140 +
92 141 ctx = multiprocessing.get_context('spawn')
93 142 class SegmentProcess(ctx.Process):
94 143 def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127):
... ... @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process):
138 187 mode="r",
139 188 )
140 189  
141   - print(image.min(), image.max())
142   - if self.apply_segment_threshold:
  190 + if self.apply_wwwl:
143 191 print("Applying window level")
144 192 image = get_LUT_value(image, self.window_width, self.window_level)
145 193  
... ... @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process):
153 201 self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+"
154 202 )
155 203  
156   - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
157   - brain_segment(image, probability_array, comm_array)
  204 + if self.backend.lower() == "pytorch":
  205 + brain_segment_torch(image, self.device_id, probability_array, comm_array)
  206 + else:
  207 + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
  208 + brain_segment(image, probability_array, comm_array)
158 209  
159 210 @property
160 211 def exception(self):
... ...
optional-requirements.txt
  1 +aioconsole==0.3.2
1 2 mido==1.2.10
  3 +nest-asyncio==1.5.1
2 4 python-rtmidi==1.4.9
3 5 python-socketio[client]==5.3.0
  6 +requests==2.26.0
  7 +uvicorn[standard]==0.15.0
... ...
requirements.txt
... ... @@ -15,3 +15,4 @@ scipy==1.7.1
15 15 vtk==9.0.3
16 16 wxPython==4.1.1
17 17 Theano==1.0.5
  18 +torch==1.9.1
... ...
scripts/invesalius_server.py 0 → 100644
... ... @@ -0,0 +1,88 @@
  1 +#!/usr/bin/env python3
  2 +# -*- coding: utf-8 -*-
  3 +
  4 +# This scripts allows sending events to InVesalius via Socket.IO, mimicking InVesalius's
  5 +# internal communication. It can be useful for developing and debugging InVesalius.
  6 +#
  7 +# Example usage:
  8 +#
  9 +# - (In console window 1) Run the script by: python scripts/invesalius_server.py 5000
  10 +#
  11 +# - (In console window 2) Run InVesalius by: python app.py --remote-host http://localhost:5000
  12 +#
  13 +# - If InVesalius connected to the server successfully, a message should appear in console window 1,
  14 +# asking to provide the topic name.
  15 +#
  16 +# - Enter the topic name, such as "Add marker" (without quotes).
  17 +#
  18 +# - Enter the data, such as {"ball_id": 0, "size": 2, "colour": [1.0, 1.0, 0.0], "coord": [10.0, 20.0, 30.0]}
  19 +#
  20 +# - If successful, a message should now appear in console window 2, indicating that the event was received.
  21 +
  22 +import asyncio
  23 +import sys
  24 +import json
  25 +
  26 +import aioconsole
  27 +import nest_asyncio
  28 +import socketio
  29 +import uvicorn
  30 +
  31 +nest_asyncio.apply()
  32 +
  33 +if len(sys.argv) != 2:
  34 + print ("""This script allows sending events to InVesalius.
  35 +
  36 +Usage: python invesalius_server.py port""")
  37 + sys.exit(1)
  38 +
  39 +port = int(sys.argv[1])
  40 +
  41 +sio = socketio.AsyncServer(async_mode='asgi')
  42 +app = socketio.ASGIApp(sio)
  43 +
  44 +connected = False
  45 +
  46 +@sio.event
  47 +def connect(sid, environ):
  48 + global connected
  49 + connected = True
  50 +
  51 +def print_json_error(e):
  52 + print("Invalid JSON")
  53 + print(e.doc)
  54 + print(" " * e.pos + "^")
  55 + print(e.msg)
  56 + print("")
  57 +
  58 +async def run():
  59 + while True:
  60 + if not connected:
  61 + await asyncio.sleep(1)
  62 + continue
  63 +
  64 + print("Enter topic: ")
  65 + topic = await aioconsole.ainput()
  66 + print("Enter data as JSON: ")
  67 + data = await aioconsole.ainput()
  68 +
  69 + try:
  70 + decoded = json.loads(data)
  71 + except json.decoder.JSONDecodeError as e:
  72 + print_json_error(e)
  73 + continue
  74 +
  75 + await sio.emit(
  76 + event="to_neuronavigation",
  77 + data={
  78 + "topic": topic,
  79 + "data": decoded,
  80 + }
  81 + )
  82 +
  83 +async def main():
  84 + asyncio.create_task(run())
  85 + uvicorn.run(app, port=port, host='0.0.0.0', loop='asyncio')
  86 +
  87 +if __name__ == '__main__':
  88 + asyncio.run(main(), debug=True)
... ...