Commit b6ae340436ae6e2556de9f1521b7ffc029cac8a6

Authored by Renan
Committed by GitHub
2 parents 75cc013e 2a9da902
Exists in master

Merge branch 'master' into multimodal_tracking

invesalius/constants.py
@@ -830,8 +830,14 @@ TREKKER_CONFIG = {'seed_max': 1, 'step_size': 0.1, 'min_fod': 0.1, 'probe_qualit @@ -830,8 +830,14 @@ TREKKER_CONFIG = {'seed_max': 1, 'step_size': 0.1, 'min_fod': 0.1, 'probe_qualit
830 830
831 MARKER_FILE_MAGICK_STRING = "INVESALIUS3_MARKER_FILE_" 831 MARKER_FILE_MAGICK_STRING = "INVESALIUS3_MARKER_FILE_"
832 CURRENT_MARKER_FILE_VERSION = 0 832 CURRENT_MARKER_FILE_VERSION = 0
  833 +
833 WILDCARD_MARKER_FILES = _("Marker scanner coord files (*.mkss)|*.mkss") 834 WILDCARD_MARKER_FILES = _("Marker scanner coord files (*.mkss)|*.mkss")
834 835
  836 +# Serial port
  837 +BAUD_RATES = [300, 1200, 2400, 4800, 9600, 19200, 38400, 57600, 115200]
  838 +BAUD_RATE_DEFAULT_SELECTION = 4
  839 +
  840 +#Robot
835 ROBOT_ElFIN_IP = ['Select robot IP:', '143.107.220.251', '169.254.153.251', '127.0.0.1'] 841 ROBOT_ElFIN_IP = ['Select robot IP:', '143.107.220.251', '169.254.153.251', '127.0.0.1']
836 ROBOT_ElFIN_PORT = 10003 842 ROBOT_ElFIN_PORT = 10003
837 843
invesalius/data/serial_port_connection.py
@@ -28,7 +28,7 @@ from invesalius.pubsub import pub as Publisher @@ -28,7 +28,7 @@ from invesalius.pubsub import pub as Publisher
28 class SerialPortConnection(threading.Thread): 28 class SerialPortConnection(threading.Thread):
29 BINARY_PULSE = b'\x01' 29 BINARY_PULSE = b'\x01'
30 30
31 - def __init__(self, port, serial_port_queue, event, sleep_nav): 31 + def __init__(self, com_port, baud_rate, serial_port_queue, event, sleep_nav):
32 """ 32 """
33 Thread created to communicate using the serial port to interact with software during neuronavigation. 33 Thread created to communicate using the serial port to interact with software during neuronavigation.
34 """ 34 """
@@ -37,28 +37,29 @@ class SerialPortConnection(threading.Thread): @@ -37,28 +37,29 @@ class SerialPortConnection(threading.Thread):
37 self.connection = None 37 self.connection = None
38 self.stylusplh = False 38 self.stylusplh = False
39 39
40 - self.port = port 40 + self.com_port = com_port
  41 + self.baud_rate = baud_rate
41 self.serial_port_queue = serial_port_queue 42 self.serial_port_queue = serial_port_queue
42 self.event = event 43 self.event = event
43 self.sleep_nav = sleep_nav 44 self.sleep_nav = sleep_nav
44 45
45 def Connect(self): 46 def Connect(self):
46 - if self.port is None: 47 + if self.com_port is None:
47 print("Serial port init error: COM port is unset.") 48 print("Serial port init error: COM port is unset.")
48 return 49 return
49 try: 50 try:
50 import serial 51 import serial
51 - self.connection = serial.Serial(self.port, baudrate=115200, timeout=0)  
52 - print("Connection to port {} opened.".format(self.port)) 52 + self.connection = serial.Serial(self.com_port, baudrate=self.baud_rate, timeout=0)
  53 + print("Connection to port {} opened.".format(self.com_port))
53 54
54 Publisher.sendMessage('Serial port connection', state=True) 55 Publisher.sendMessage('Serial port connection', state=True)
55 except: 56 except:
56 - print("Serial port init error: Connecting to port {} failed.".format(self.port)) 57 + print("Serial port init error: Connecting to port {} failed.".format(self.com_port))
57 58
58 def Disconnect(self): 59 def Disconnect(self):
59 if self.connection: 60 if self.connection:
60 self.connection.close() 61 self.connection.close()
61 - print("Connection to port {} closed.".format(self.port)) 62 + print("Connection to port {} closed.".format(self.com_port))
62 63
63 Publisher.sendMessage('Serial port connection', state=False) 64 Publisher.sendMessage('Serial port connection', state=False)
64 65
@@ -74,12 +75,11 @@ class SerialPortConnection(threading.Thread): @@ -74,12 +75,11 @@ class SerialPortConnection(threading.Thread):
74 trigger_on = False 75 trigger_on = False
75 try: 76 try:
76 lines = self.connection.readlines() 77 lines = self.connection.readlines()
  78 + if lines:
  79 + trigger_on = True
77 except: 80 except:
78 print("Error: Serial port could not be read.") 81 print("Error: Serial port could not be read.")
79 82
80 - if lines:  
81 - trigger_on = True  
82 -  
83 if self.stylusplh: 83 if self.stylusplh:
84 trigger_on = True 84 trigger_on = True
85 self.stylusplh = False 85 self.stylusplh = False
invesalius/data/trackers.py
@@ -299,7 +299,7 @@ def PlhSerialConnection(tracker_id): @@ -299,7 +299,7 @@ def PlhSerialConnection(tracker_id):
299 import serial 299 import serial
300 from wx import ID_OK 300 from wx import ID_OK
301 trck_init = None 301 trck_init = None
302 - dlg_port = dlg.SetCOMport() 302 + dlg_port = dlg.SetCOMPort(select_baud_rate=False)
303 if dlg_port.ShowModal() == ID_OK: 303 if dlg_port.ShowModal() == ID_OK:
304 com_port = dlg_port.GetValue() 304 com_port = dlg_port.GetValue()
305 try: 305 try:
invesalius/gui/brain_seg_dialog.py
@@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano"))
22 HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) 22 HAS_PLAIDML = bool(importlib.util.find_spec("plaidml"))
23 PLAIDML_DEVICES = {} 23 PLAIDML_DEVICES = {}
24 24
  25 +try:
  26 + import torch
  27 + HAS_TORCH = True
  28 +except ImportError:
  29 + HAS_TORCH = False
  30 +
  31 +if HAS_TORCH:
  32 + TORCH_DEVICES = {}
  33 + if torch.cuda.is_available():
  34 + for i in range(torch.cuda.device_count()):
  35 + name = torch.cuda.get_device_name()
  36 + device_id = f'cuda:{i}'
  37 + TORCH_DEVICES[name] = device_id
  38 + TORCH_DEVICES['CPU'] = 'cpu'
  39 +
  40 +
25 41
26 if HAS_PLAIDML: 42 if HAS_PLAIDML:
27 with multiprocessing.Pool(1) as p: 43 with multiprocessing.Pool(1) as p:
@@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog):
43 style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, 59 style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT,
44 ) 60 )
45 backends = [] 61 backends = []
  62 + if HAS_TORCH:
  63 + backends.append("Pytorch")
46 if HAS_PLAIDML: 64 if HAS_PLAIDML:
47 backends.append("PlaidML") 65 backends.append("PlaidML")
48 if HAS_THEANO: 66 if HAS_THEANO:
49 backends.append("Theano") 67 backends.append("Theano")
50 # self.segmenter = segment.BrainSegmenter() 68 # self.segmenter = segment.BrainSegmenter()
51 # self.pg_dialog = None 69 # self.pg_dialog = None
  70 + self.torch_devices = TORCH_DEVICES
52 self.plaidml_devices = PLAIDML_DEVICES 71 self.plaidml_devices = PLAIDML_DEVICES
53 72
54 self.ps = None 73 self.ps = None
@@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog):
65 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) 84 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends)))
66 self.cb_backends.SetMinClientSize((w, -1)) 85 self.cb_backends.SetMinClientSize((w, -1))
67 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) 86 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU"))
68 - if HAS_PLAIDML: 87 + if HAS_TORCH or HAS_PLAIDML:
  88 + if HAS_TORCH:
  89 + choices = list(self.torch_devices.keys())
  90 + value = choices[0]
  91 + else:
  92 + choices = list(self.plaidml_devices.keys())
  93 + value = choices[0]
69 self.lbl_device = wx.StaticText(self, -1, _("Device")) 94 self.lbl_device = wx.StaticText(self, -1, _("Device"))
70 self.cb_devices = wx.ComboBox( 95 self.cb_devices = wx.ComboBox(
71 self, 96 self,
72 wx.ID_ANY, 97 wx.ID_ANY,
73 - choices=list(self.plaidml_devices.keys()),  
74 - value=list(self.plaidml_devices.keys())[0], 98 + choices=choices,
  99 + value=value,
75 style=wx.CB_DROPDOWN | wx.CB_READONLY, 100 style=wx.CB_DROPDOWN | wx.CB_READONLY,
76 ) 101 )
77 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) 102 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100)
@@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog):
109 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) 134 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5)
110 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) 135 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5)
111 sizer_devices = wx.BoxSizer(wx.HORIZONTAL) 136 sizer_devices = wx.BoxSizer(wx.HORIZONTAL)
112 - if HAS_PLAIDML: 137 + if HAS_TORCH or HAS_PLAIDML:
113 sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) 138 sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0)
114 sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) 139 sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5)
115 main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) 140 main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5)
@@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog):
177 return width, height 202 return width, height
178 203
179 def OnSetBackend(self, evt=None): 204 def OnSetBackend(self, evt=None):
180 - if self.cb_backends.GetValue().lower() == "plaidml": 205 + if self.cb_backends.GetValue().lower() == "pytorch":
  206 + if HAS_TORCH:
  207 + choices = list(self.torch_devices.keys())
  208 + self.cb_devices.Clear()
  209 + self.cb_devices.SetItems(choices)
  210 + self.cb_devices.SetValue(choices[0])
  211 + self.lbl_device.Show()
  212 + self.cb_devices.Show()
  213 + self.chk_use_gpu.Hide()
  214 + elif self.cb_backends.GetValue().lower() == "plaidml":
181 if HAS_PLAIDML: 215 if HAS_PLAIDML:
  216 + choices = list(self.plaidml_devices.keys())
  217 + self.cb_devices.Clear()
  218 + self.cb_devices.SetItems(choices)
  219 + self.cb_devices.SetValue(choices[0])
182 self.lbl_device.Show() 220 self.lbl_device.Show()
183 self.cb_devices.Show() 221 self.cb_devices.Show()
184 self.chk_use_gpu.Hide() 222 self.chk_use_gpu.Hide()
@@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog):
216 self.elapsed_time_timer.Start(1000) 254 self.elapsed_time_timer.Start(1000)
217 image = slc.Slice().matrix 255 image = slc.Slice().matrix
218 backend = self.cb_backends.GetValue() 256 backend = self.cb_backends.GetValue()
219 - try:  
220 - device_id = self.plaidml_devices[self.cb_devices.GetValue()]  
221 - except (KeyError, AttributeError):  
222 - device_id = "llvm_cpu.0" 257 + if backend.lower() == "pytorch":
  258 + try:
  259 + device_id = self.torch_devices[self.cb_devices.GetValue()]
  260 + except (KeyError, AttributeError):
  261 + device_id = "cpu"
  262 + else:
  263 + try:
  264 + device_id = self.plaidml_devices[self.cb_devices.GetValue()]
  265 + except (KeyError, AttributeError):
  266 + device_id = "llvm_cpu.0"
223 apply_wwwl = self.chk_apply_wwwl.GetValue() 267 apply_wwwl = self.chk_apply_wwwl.GetValue()
224 create_new_mask = self.chk_new_mask.GetValue() 268 create_new_mask = self.chk_new_mask.GetValue()
225 use_gpu = self.chk_use_gpu.GetValue() 269 use_gpu = self.chk_use_gpu.GetValue()
invesalius/gui/dialogs.py
@@ -4632,7 +4632,8 @@ class SetNDIconfigs(wx.Dialog): @@ -4632,7 +4632,8 @@ class SetNDIconfigs(wx.Dialog):
4632 self._init_gui() 4632 self._init_gui()
4633 4633
4634 def serial_ports(self): 4634 def serial_ports(self):
4635 - """ Lists serial port names and pre-select the description containing NDI 4635 + """
  4636 + Lists serial port names and pre-select the description containing NDI
4636 """ 4637 """
4637 import serial.tools.list_ports 4638 import serial.tools.list_ports
4638 4639
@@ -4748,13 +4749,16 @@ class SetNDIconfigs(wx.Dialog): @@ -4748,13 +4749,16 @@ class SetNDIconfigs(wx.Dialog):
4748 return self.com_ports.GetString(self.com_ports.GetSelection()).encode(const.FS_ENCODE), fn_probe, fn_ref, fn_obj 4749 return self.com_ports.GetString(self.com_ports.GetSelection()).encode(const.FS_ENCODE), fn_probe, fn_ref, fn_obj
4749 4750
4750 4751
4751 -class SetCOMport(wx.Dialog):  
4752 - def __init__(self, title=_("Select COM port")):  
4753 - wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE|wx.FRAME_FLOAT_ON_PARENT|wx.STAY_ON_TOP) 4752 +class SetCOMPort(wx.Dialog):
  4753 + def __init__(self, select_baud_rate, title=_("Select COM port")):
  4754 + wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT | wx.STAY_ON_TOP)
  4755 +
  4756 + self.select_baud_rate = select_baud_rate
4754 self._init_gui() 4757 self._init_gui()
4755 4758
4756 def serial_ports(self): 4759 def serial_ports(self):
4757 - """ Lists serial port names 4760 + """
  4761 + Lists serial port names
4758 """ 4762 """
4759 import serial.tools.list_ports 4763 import serial.tools.list_ports
4760 if sys.platform.startswith('win'): 4764 if sys.platform.startswith('win'):
@@ -4764,12 +4768,26 @@ class SetCOMport(wx.Dialog): @@ -4764,12 +4768,26 @@ class SetCOMport(wx.Dialog):
4764 return ports 4768 return ports
4765 4769
4766 def _init_gui(self): 4770 def _init_gui(self):
4767 - self.com_ports = wx.ComboBox(self, -1, style=wx.CB_DROPDOWN|wx.CB_READONLY) 4771 + # COM port selection
4768 ports = self.serial_ports() 4772 ports = self.serial_ports()
4769 - self.com_ports.Append(ports) 4773 + self.com_port_dropdown = wx.ComboBox(self, -1, choices=ports, style=wx.CB_DROPDOWN | wx.CB_READONLY)
  4774 + self.com_port_dropdown.SetSelection(0)
  4775 +
  4776 + com_port_text_and_dropdown = wx.BoxSizer(wx.VERTICAL)
  4777 + com_port_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "COM port"), 0, wx.TOP | wx.RIGHT,5)
  4778 + com_port_text_and_dropdown.Add(self.com_port_dropdown, 0, wx.EXPAND)
  4779 +
  4780 + # Baud rate selection
  4781 + if self.select_baud_rate:
  4782 + baud_rates_as_strings = [str(baud_rate) for baud_rate in const.BAUD_RATES]
  4783 + self.baud_rate_dropdown = wx.ComboBox(self, -1, choices=baud_rates_as_strings, style=wx.CB_DROPDOWN | wx.CB_READONLY)
  4784 + self.baud_rate_dropdown.SetSelection(const.BAUD_RATE_DEFAULT_SELECTION)
4770 4785
4771 - # self.goto_orientation.SetSelection(cb_init) 4786 + baud_rate_text_and_dropdown = wx.BoxSizer(wx.VERTICAL)
  4787 + baud_rate_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "Baud rate"), 0, wx.TOP | wx.RIGHT,5)
  4788 + baud_rate_text_and_dropdown.Add(self.baud_rate_dropdown, 0, wx.EXPAND)
4772 4789
  4790 + # OK and Cancel buttons
4773 btn_ok = wx.Button(self, wx.ID_OK) 4791 btn_ok = wx.Button(self, wx.ID_OK)
4774 btn_ok.SetHelpText("") 4792 btn_ok.SetHelpText("")
4775 btn_ok.SetDefault() 4793 btn_ok.SetDefault()
@@ -4782,10 +4800,16 @@ class SetCOMport(wx.Dialog): @@ -4782,10 +4800,16 @@ class SetCOMport(wx.Dialog):
4782 btnsizer.AddButton(btn_cancel) 4800 btnsizer.AddButton(btn_cancel)
4783 btnsizer.Realize() 4801 btnsizer.Realize()
4784 4802
  4803 + # Set up the main sizer
4785 main_sizer = wx.BoxSizer(wx.VERTICAL) 4804 main_sizer = wx.BoxSizer(wx.VERTICAL)
4786 4805
4787 main_sizer.Add((5, 5)) 4806 main_sizer.Add((5, 5))
4788 - main_sizer.Add(self.com_ports, 1, wx.EXPAND|wx.LEFT|wx.RIGHT, 5) 4807 + main_sizer.Add(com_port_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5)
  4808 +
  4809 + if self.select_baud_rate:
  4810 + main_sizer.Add((5, 5))
  4811 + main_sizer.Add(baud_rate_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5)
  4812 +
4789 main_sizer.Add((5, 5)) 4813 main_sizer.Add((5, 5))
4790 main_sizer.Add(btnsizer, 0, wx.EXPAND) 4814 main_sizer.Add(btnsizer, 0, wx.EXPAND)
4791 main_sizer.Add((5, 5)) 4815 main_sizer.Add((5, 5))
@@ -4796,7 +4820,14 @@ class SetCOMport(wx.Dialog): @@ -4796,7 +4820,14 @@ class SetCOMport(wx.Dialog):
4796 self.CenterOnParent() 4820 self.CenterOnParent()
4797 4821
4798 def GetValue(self): 4822 def GetValue(self):
4799 - return self.com_ports.GetString(self.com_ports.GetSelection()) 4823 + com_port = self.com_port_dropdown.GetString(self.com_port_dropdown.GetSelection())
  4824 +
  4825 + if self.select_baud_rate:
  4826 + baud_rate = self.baud_rate_dropdown.GetString(self.baud_rate_dropdown.GetSelection())
  4827 + else:
  4828 + baud_rate = None
  4829 +
  4830 + return com_port, baud_rate
4800 4831
4801 4832
4802 class ManualWWWLDialog(wx.Dialog): 4833 class ManualWWWLDialog(wx.Dialog):
invesalius/gui/task_navigator.py
@@ -234,8 +234,8 @@ class InnerFoldPanel(wx.Panel): @@ -234,8 +234,8 @@ class InnerFoldPanel(wx.Panel):
234 checkcamera.Bind(wx.EVT_CHECKBOX, self.OnVolumeCamera) 234 checkcamera.Bind(wx.EVT_CHECKBOX, self.OnVolumeCamera)
235 self.checkcamera = checkcamera 235 self.checkcamera = checkcamera
236 236
237 - # Check box to create markers from serial port  
238 - tooltip = wx.ToolTip(_("Enable serial port communication for creating markers")) 237 + # Check box to use serial port to trigger pulse signal and create markers
  238 + tooltip = wx.ToolTip(_("Enable serial port communication to trigger pulse and create markers"))
239 checkbox_serial_port = wx.CheckBox(self, -1, _('Serial port')) 239 checkbox_serial_port = wx.CheckBox(self, -1, _('Serial port'))
240 checkbox_serial_port.SetToolTip(tooltip) 240 checkbox_serial_port.SetToolTip(tooltip)
241 checkbox_serial_port.SetValue(False) 241 checkbox_serial_port.SetValue(False)
@@ -297,14 +297,20 @@ class InnerFoldPanel(wx.Panel): @@ -297,14 +297,20 @@ class InnerFoldPanel(wx.Panel):
297 self.checkobj.Enable(True) 297 self.checkobj.Enable(True)
298 298
299 def OnEnableSerialPort(self, evt, ctrl): 299 def OnEnableSerialPort(self, evt, ctrl):
300 - com_port = None  
301 if ctrl.GetValue(): 300 if ctrl.GetValue():
302 from wx import ID_OK 301 from wx import ID_OK
303 - dlg_port = dlg.SetCOMport()  
304 - if dlg_port.ShowModal() == ID_OK:  
305 - com_port = dlg_port.GetValue() 302 + dlg_port = dlg.SetCOMPort(select_baud_rate=False)
306 303
307 - Publisher.sendMessage('Update serial port', serial_port=com_port) 304 + if dlg_port.ShowModal() != ID_OK:
  305 + ctrl.SetValue(False)
  306 + return
  307 +
  308 + com_port = dlg_port.GetValue()
  309 + baud_rate = 115200
  310 +
  311 + Publisher.sendMessage('Update serial port', serial_port_in_use=True, com_port=com_port, baud_rate=baud_rate)
  312 + else:
  313 + Publisher.sendMessage('Update serial port', serial_port_in_use=False)
308 314
309 def OnShowObject(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): 315 def OnShowObject(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True):
310 if not evt: 316 if not evt:
@@ -503,7 +509,6 @@ class NeuronavigationPanel(wx.Panel): @@ -503,7 +509,6 @@ class NeuronavigationPanel(wx.Panel):
503 Publisher.subscribe(self.LoadImageFiducials, 'Load image fiducials') 509 Publisher.subscribe(self.LoadImageFiducials, 'Load image fiducials')
504 Publisher.subscribe(self.SetImageFiducial, 'Set image fiducial') 510 Publisher.subscribe(self.SetImageFiducial, 'Set image fiducial')
505 Publisher.subscribe(self.SetTrackerFiducial, 'Set tracker fiducial') 511 Publisher.subscribe(self.SetTrackerFiducial, 'Set tracker fiducial')
506 - Publisher.subscribe(self.UpdateSerialPort, 'Update serial port')  
507 Publisher.subscribe(self.UpdateTrackObjectState, 'Update track object state') 512 Publisher.subscribe(self.UpdateTrackObjectState, 'Update track object state')
508 Publisher.subscribe(self.UpdateImageCoordinates, 'Set cross focal point') 513 Publisher.subscribe(self.UpdateImageCoordinates, 'Set cross focal point')
509 Publisher.subscribe(self.OnDisconnectTracker, 'Disconnect tracker') 514 Publisher.subscribe(self.OnDisconnectTracker, 'Disconnect tracker')
@@ -627,9 +632,6 @@ class NeuronavigationPanel(wx.Panel): @@ -627,9 +632,6 @@ class NeuronavigationPanel(wx.Panel):
627 def UpdateTrackObjectState(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): 632 def UpdateTrackObjectState(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True):
628 self.navigation.track_obj = flag 633 self.navigation.track_obj = flag
629 634
630 - def UpdateSerialPort(self, serial_port):  
631 - self.navigation.serial_port = serial_port  
632 -  
633 def ResetICP(self): 635 def ResetICP(self):
634 self.icp.ResetICP() 636 self.icp.ResetICP()
635 self.checkbox_icp.Enable(False) 637 self.checkbox_icp.Enable(False)
invesalius/inv_paths.py
@@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co
27 USER_INV_DIR = CONF_DIR.joinpath("invesalius") 27 USER_INV_DIR = CONF_DIR.joinpath("invesalius")
28 USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") 28 USER_PRESET_DIR = USER_INV_DIR.joinpath("presets")
29 USER_LOG_DIR = USER_INV_DIR.joinpath("logs") 29 USER_LOG_DIR = USER_INV_DIR.joinpath("logs")
  30 +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/")
30 USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") 31 USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting")
31 TEMP_DIR = tempfile.gettempdir() 32 TEMP_DIR = tempfile.gettempdir()
32 33
@@ -97,6 +98,7 @@ def create_conf_folders(): @@ -97,6 +98,7 @@ def create_conf_folders():
97 USER_INV_DIR.mkdir(parents=True, exist_ok=True) 98 USER_INV_DIR.mkdir(parents=True, exist_ok=True)
98 USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) 99 USER_PRESET_DIR.mkdir(parents=True, exist_ok=True)
99 USER_LOG_DIR.mkdir(parents=True, exist_ok=True) 100 USER_LOG_DIR.mkdir(parents=True, exist_ok=True)
  101 + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True)
100 USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) 102 USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True)
101 103
102 104
invesalius/navigation/navigation.py
@@ -171,7 +171,9 @@ class Navigation(): @@ -171,7 +171,9 @@ class Navigation():
171 self.sleep_nav = const.SLEEP_NAVIGATION 171 self.sleep_nav = const.SLEEP_NAVIGATION
172 172
173 # Serial port 173 # Serial port
174 - self.serial_port = None 174 + self.serial_port_in_use = False
  175 + self.com_port = None
  176 + self.baud_rate = None
175 self.serial_port_connection = None 177 self.serial_port_connection = None
176 178
177 # During navigation 179 # During navigation
@@ -181,6 +183,7 @@ class Navigation(): @@ -181,6 +183,7 @@ class Navigation():
181 183
182 def __bind_events(self): 184 def __bind_events(self):
183 Publisher.subscribe(self.CoilAtTarget, 'Coil at target') 185 Publisher.subscribe(self.CoilAtTarget, 'Coil at target')
  186 + Publisher.subscribe(self.UpdateSerialPort, 'Update serial port')
184 187
185 def CoilAtTarget(self, state): 188 def CoilAtTarget(self, state):
186 self.coil_at_target = state 189 self.coil_at_target = state
@@ -189,8 +192,10 @@ class Navigation(): @@ -189,8 +192,10 @@ class Navigation():
189 self.sleep_nav = sleep 192 self.sleep_nav = sleep
190 self.serial_port_connection.sleep_nav = sleep 193 self.serial_port_connection.sleep_nav = sleep
191 194
192 - def SerialPortEnabled(self):  
193 - return self.serial_port is not None 195 + def UpdateSerialPort(self, serial_port_in_use, com_port=None, baud_rate=None):
  196 + self.serial_port_in_use = serial_port_in_use
  197 + self.com_port = com_port
  198 + self.baud_rate = baud_rate
194 199
195 def SetReferenceMode(self, value): 200 def SetReferenceMode(self, value):
196 self.ref_mode_id = value 201 self.ref_mode_id = value
@@ -218,7 +223,7 @@ class Navigation(): @@ -218,7 +223,7 @@ class Navigation():
218 return fre, fre <= const.FIDUCIAL_REGISTRATION_ERROR_THRESHOLD 223 return fre, fre <= const.FIDUCIAL_REGISTRATION_ERROR_THRESHOLD
219 224
220 def PedalStateChanged(self, state): 225 def PedalStateChanged(self, state):
221 - if state is True and self.coil_at_target and self.SerialPortEnabled(): 226 + if state is True and self.coil_at_target and self.serial_port_in_use:
222 self.serial_port_connection.SendPulse() 227 self.serial_port_connection.SendPulse()
223 228
224 def StartNavigation(self, tracker): 229 def StartNavigation(self, tracker):
@@ -230,7 +235,7 @@ class Navigation(): @@ -230,7 +235,7 @@ class Navigation():
230 if self.event.is_set(): 235 if self.event.is_set():
231 self.event.clear() 236 self.event.clear()
232 237
233 - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded] 238 + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded]
234 vis_queues = [self.coord_queue, self.serial_port_queue, self.tracts_queue, self.icp_queue, self.robottarget_queue] 239 vis_queues = [self.coord_queue, self.serial_port_queue, self.tracts_queue, self.icp_queue, self.robottarget_queue]
235 240
236 Publisher.sendMessage("Navigation status", nav_status=True, vis_status=vis_components) 241 Publisher.sendMessage("Navigation status", nav_status=True, vis_status=vis_components)
@@ -279,12 +284,13 @@ class Navigation(): @@ -279,12 +284,13 @@ class Navigation():
279 284
280 if not errors: 285 if not errors:
281 #TODO: Test the serial port thread 286 #TODO: Test the serial port thread
282 - if self.SerialPortEnabled(): 287 + if self.serial_port_in_use:
283 self.serial_port_connection = spc.SerialPortConnection( 288 self.serial_port_connection = spc.SerialPortConnection(
284 - self.serial_port,  
285 - self.serial_port_queue,  
286 - self.event,  
287 - self.sleep_nav, 289 + com_port=self.com_port,
  290 + baud_rate=self.baud_rate,
  291 + serial_port_queue=self.serial_port_queue,
  292 + event=self.event,
  293 + sleep_nav=self.sleep_nav,
288 ) 294 )
289 self.serial_port_connection.Connect() 295 self.serial_port_connection.Connect()
290 jobs_list.append(self.serial_port_connection) 296 jobs_list.append(self.serial_port_connection)
@@ -330,7 +336,7 @@ class Navigation(): @@ -330,7 +336,7 @@ class Navigation():
330 if self.serial_port_connection is not None: 336 if self.serial_port_connection is not None:
331 self.serial_port_connection.join() 337 self.serial_port_connection.join()
332 338
333 - if self.SerialPortEnabled(): 339 + if self.serial_port_in_use:
334 self.serial_port_queue.clear() 340 self.serial_port_queue.clear()
335 self.serial_port_queue.join() 341 self.serial_port_queue.join()
336 342
@@ -341,5 +347,5 @@ class Navigation(): @@ -341,5 +347,5 @@ class Navigation():
341 self.tracts_queue.clear() 347 self.tracts_queue.clear()
342 self.tracts_queue.join() 348 self.tracts_queue.join()
343 349
344 - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded] 350 + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded]
345 Publisher.sendMessage("Navigation status", nav_status=False, vis_status=vis_components) 351 Publisher.sendMessage("Navigation status", nav_status=False, vis_status=vis_components)
invesalius/net/utils.py 0 → 100644
@@ -0,0 +1,48 @@ @@ -0,0 +1,48 @@
  1 +from urllib.error import HTTPError
  2 +from urllib.request import urlopen, Request
  3 +from urllib.parse import urlparse
  4 +import pathlib
  5 +import tempfile
  6 +import typing
  7 +import hashlib
  8 +import os
  9 +import shutil
  10 +
  11 +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None):
  12 + file_size = None
  13 + total_downloaded = 0
  14 + if hash is not None:
  15 + calc_hash = hashlib.sha256()
  16 + req = Request(url)
  17 + response = urlopen(req)
  18 + meta = response.info()
  19 + if hasattr(meta, "getheaders"):
  20 + content_length = meta.getheaders("Content-Length")
  21 + else:
  22 + content_length = meta.get_all("Content-Length")
  23 +
  24 + if content_length is not None and len(content_length) > 0:
  25 + file_size = int(content_length[0])
  26 + dst.parent.mkdir(parents=True, exist_ok=True)
  27 + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent)
  28 + try:
  29 + while True:
  30 + buffer = response.read(8192)
  31 + if len(buffer) == 0:
  32 + break
  33 + total_downloaded += len(buffer)
  34 + f.write(buffer)
  35 + if hash:
  36 + calc_hash.update(buffer)
  37 + if callback is not None:
  38 + callback(100 * total_downloaded/file_size)
  39 + f.close()
  40 + if hash is not None:
  41 + digest = calc_hash.hexdigest()
  42 + if digest != hash:
  43 + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")')
  44 + shutil.move(f.name, dst)
  45 + finally:
  46 + f.close()
  47 + if os.path.exists(f.name):
  48 + os.remove(f.name)
invesalius/segmentation/brain/model.py 0 → 100644
@@ -0,0 +1,149 @@ @@ -0,0 +1,149 @@
  1 +from collections import OrderedDict
  2 +
  3 +import torch
  4 +import torch.nn as nn
  5 +
  6 +SIZE = 48
  7 +
  8 +class Unet3D(nn.Module):
  9 + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py
  10 + def __init__(self, in_channels=1, out_channels=1, init_features=8):
  11 + super().__init__()
  12 + features = init_features
  13 +
  14 + self.encoder1 = self._block(
  15 + in_channels, features=features, padding=2, name="enc1"
  16 + )
  17 + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
  18 +
  19 + self.encoder2 = self._block(
  20 + features, features=features * 2, padding=2, name="enc2"
  21 + )
  22 + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
  23 +
  24 + self.encoder3 = self._block(
  25 + features * 2, features=features * 4, padding=2, name="enc3"
  26 + )
  27 + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
  28 +
  29 + self.encoder4 = self._block(
  30 + features * 4, features=features * 8, padding=2, name="enc4"
  31 + )
  32 + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
  33 +
  34 + self.bottleneck = self._block(
  35 + features * 8, features=features * 16, padding=2, name="bottleneck"
  36 + )
  37 +
  38 + self.upconv4 = nn.ConvTranspose3d(
  39 + features * 16, features * 8, kernel_size=4, stride=2, padding=1
  40 + )
  41 + self.decoder4 = self._block(
  42 + features * 16, features=features * 8, padding=2, name="dec4"
  43 + )
  44 +
  45 + self.upconv3 = nn.ConvTranspose3d(
  46 + features * 8, features * 4, kernel_size=4, stride=2, padding=1
  47 + )
  48 + self.decoder3 = self._block(
  49 + features * 8, features=features * 4, padding=2, name="dec4"
  50 + )
  51 +
  52 + self.upconv2 = nn.ConvTranspose3d(
  53 + features * 4, features * 2, kernel_size=4, stride=2, padding=1
  54 + )
  55 + self.decoder2 = self._block(
  56 + features * 4, features=features * 2, padding=2, name="dec4"
  57 + )
  58 +
  59 + self.upconv1 = nn.ConvTranspose3d(
  60 + features * 2, features, kernel_size=4, stride=2, padding=1
  61 + )
  62 + self.decoder1 = self._block(
  63 + features * 2, features=features, padding=2, name="dec4"
  64 + )
  65 +
  66 + self.conv = nn.Conv3d(
  67 + in_channels=features, out_channels=out_channels, kernel_size=1
  68 + )
  69 +
  70 + def forward(self, img):
  71 + enc1 = self.encoder1(img)
  72 + enc2 = self.encoder2(self.pool1(enc1))
  73 + enc3 = self.encoder3(self.pool2(enc2))
  74 + enc4 = self.encoder4(self.pool3(enc3))
  75 +
  76 + bottleneck = self.bottleneck(self.pool4(enc4))
  77 +
  78 + upconv4 = self.upconv4(bottleneck)
  79 + dec4 = torch.cat((upconv4, enc4), dim=1)
  80 + dec4 = self.decoder4(dec4)
  81 +
  82 + upconv3 = self.upconv3(dec4)
  83 + dec3 = torch.cat((upconv3, enc3), dim=1)
  84 + dec3 = self.decoder3(dec3)
  85 +
  86 + upconv2 = self.upconv2(dec3)
  87 + dec2 = torch.cat((upconv2, enc2), dim=1)
  88 + dec2 = self.decoder2(dec2)
  89 +
  90 + upconv1 = self.upconv1(dec2)
  91 + dec1 = torch.cat((upconv1, enc1), dim=1)
  92 + dec1 = self.decoder1(dec1)
  93 +
  94 + conv = self.conv(dec1)
  95 +
  96 + sigmoid = torch.sigmoid(conv)
  97 +
  98 + return sigmoid
  99 +
  100 + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"):
  101 + return nn.Sequential(
  102 + OrderedDict(
  103 + (
  104 + (
  105 + f"{name}_conv1",
  106 + nn.Conv3d(
  107 + in_channels=in_channels,
  108 + out_channels=features,
  109 + kernel_size=kernel_size,
  110 + padding=padding,
  111 + bias=True,
  112 + ),
  113 + ),
  114 + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)),
  115 + (f"{name}_relu1", nn.ReLU(inplace=True)),
  116 + (
  117 + f"{name}_conv2",
  118 + nn.Conv3d(
  119 + in_channels=features,
  120 + out_channels=features,
  121 + kernel_size=kernel_size,
  122 + padding=padding,
  123 + bias=True,
  124 + ),
  125 + ),
  126 + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)),
  127 + (f"{name}_relu2", nn.ReLU(inplace=True)),
  128 + )
  129 + )
  130 + )
  131 +
  132 +
  133 +def main():
  134 + import torchviz
  135 + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  136 + model = Unet3D()
  137 + model.to(dev)
  138 + model.eval()
  139 + print(next(model.parameters()).is_cuda) # True
  140 + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev)
  141 + out = model(img)
  142 + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
  143 + dot.render("unet", format="png")
  144 + torch.save(model, "model.pth")
  145 + print(dot)
  146 +
  147 +
  148 +if __name__ == "__main__":
  149 + main()
invesalius/segmentation/brain/segment.py
@@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc
13 from invesalius import inv_paths 13 from invesalius import inv_paths
14 from invesalius.data import imagedata_utils 14 from invesalius.data import imagedata_utils
15 from invesalius.utils import new_name_by_pattern 15 from invesalius.utils import new_name_by_pattern
  16 +from invesalius.net.utils import download_url_to_file
  17 +from invesalius import inv_paths
16 18
17 from . import utils 19 from . import utils
18 20
@@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE):
64 0 : ez - iz, 0 : ey - iy, 0 : ex - ix 66 0 : ez - iz, 0 : ey - iy, 0 : ex - ix
65 ] 67 ]
66 68
  69 +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE):
  70 + import torch
  71 + with torch.no_grad():
  72 + (iz, ez), (iy, ey), (ix, ex) = patch
  73 + sub_mask = nn_model(
  74 + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device)
  75 + ).cpu().numpy()
  76 + return sub_mask.reshape(patch_size, patch_size, patch_size)[
  77 + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix
  78 + ]
  79 +
67 80
68 def brain_segment(image, probability_array, comm_array): 81 def brain_segment(image, probability_array, comm_array):
69 import keras 82 import keras
@@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array):
89 comm_array[0] = np.Inf 102 comm_array[0] = np.Inf
90 103
91 104
  105 +def download_callback(comm_array):
  106 + def _download_callback(value):
  107 + comm_array[0] = value
  108 + return _download_callback
  109 +
  110 +def brain_segment_torch(image, device_id, probability_array, comm_array):
  111 + import torch
  112 + from .model import Unet3D
  113 + device = torch.device(device_id)
  114 + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt")
  115 + if not state_dict_file.exists():
  116 + download_url_to_file(
  117 + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt",
  118 + state_dict_file,
  119 + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22",
  120 + download_callback(comm_array)
  121 + )
  122 + state_dict = torch.load(str(state_dict_file))
  123 + model = Unet3D()
  124 + model.load_state_dict(state_dict["model_state_dict"])
  125 + model.to(device)
  126 + model.eval()
  127 +
  128 + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32)
  129 + sums = np.zeros_like(image)
  130 + # segmenting by patches
  131 + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP):
  132 + comm_array[0] = completion
  133 + (iz, ez), (iy, ey), (ix, ex) = patch
  134 + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE)
  135 + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask
  136 + sums[iz:ez, iy:ey, ix:ex] += 1
  137 +
  138 + probability_array /= sums
  139 + comm_array[0] = np.Inf
  140 +
92 ctx = multiprocessing.get_context('spawn') 141 ctx = multiprocessing.get_context('spawn')
93 class SegmentProcess(ctx.Process): 142 class SegmentProcess(ctx.Process):
94 def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): 143 def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127):
@@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process):
138 mode="r", 187 mode="r",
139 ) 188 )
140 189
141 - print(image.min(), image.max())  
142 - if self.apply_segment_threshold: 190 + if self.apply_wwwl:
143 print("Applying window level") 191 print("Applying window level")
144 image = get_LUT_value(image, self.window_width, self.window_level) 192 image = get_LUT_value(image, self.window_width, self.window_level)
145 193
@@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process):
153 self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" 201 self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+"
154 ) 202 )
155 203
156 - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)  
157 - brain_segment(image, probability_array, comm_array) 204 + if self.backend.lower() == "pytorch":
  205 + brain_segment_torch(image, self.device_id, probability_array, comm_array)
  206 + else:
  207 + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
  208 + brain_segment(image, probability_array, comm_array)
158 209
159 @property 210 @property
160 def exception(self): 211 def exception(self):
optional-requirements.txt
  1 +aioconsole==0.3.2
1 mido==1.2.10 2 mido==1.2.10
  3 +nest-asyncio==1.5.1
2 python-rtmidi==1.4.9 4 python-rtmidi==1.4.9
3 python-socketio[client]==5.3.0 5 python-socketio[client]==5.3.0
  6 +requests==2.26.0
  7 +uvicorn[standard]==0.15.0
requirements.txt
@@ -15,3 +15,4 @@ scipy==1.7.1 @@ -15,3 +15,4 @@ scipy==1.7.1
15 vtk==9.0.3 15 vtk==9.0.3
16 wxPython==4.1.1 16 wxPython==4.1.1
17 Theano==1.0.5 17 Theano==1.0.5
  18 +torch==1.9.1
scripts/invesalius_server.py 0 → 100644
@@ -0,0 +1,88 @@ @@ -0,0 +1,88 @@
  1 +#!/usr/bin/env python3
  2 +# -*- coding: utf-8 -*-
  3 +
  4 +# This scripts allows sending events to InVesalius via Socket.IO, mimicking InVesalius's
  5 +# internal communication. It can be useful for developing and debugging InVesalius.
  6 +#
  7 +# Example usage:
  8 +#
  9 +# - (In console window 1) Run the script by: python scripts/invesalius_server.py 5000
  10 +#
  11 +# - (In console window 2) Run InVesalius by: python app.py --remote-host http://localhost:5000
  12 +#
  13 +# - If InVesalius connected to the server successfully, a message should appear in console window 1,
  14 +# asking to provide the topic name.
  15 +#
  16 +# - Enter the topic name, such as "Add marker" (without quotes).
  17 +#
  18 +# - Enter the data, such as {"ball_id": 0, "size": 2, "colour": [1.0, 1.0, 0.0], "coord": [10.0, 20.0, 30.0]}
  19 +#
  20 +# - If successful, a message should now appear in console window 2, indicating that the event was received.
  21 +
  22 +import asyncio
  23 +import sys
  24 +import json
  25 +
  26 +import aioconsole
  27 +import nest_asyncio
  28 +import socketio
  29 +import uvicorn
  30 +
  31 +nest_asyncio.apply()
  32 +
  33 +if len(sys.argv) != 2:
  34 + print ("""This script allows sending events to InVesalius.
  35 +
  36 +Usage: python invesalius_server.py port""")
  37 + sys.exit(1)
  38 +
  39 +port = int(sys.argv[1])
  40 +
  41 +sio = socketio.AsyncServer(async_mode='asgi')
  42 +app = socketio.ASGIApp(sio)
  43 +
  44 +connected = False
  45 +
  46 +@sio.event
  47 +def connect(sid, environ):
  48 + global connected
  49 + connected = True
  50 +
  51 +def print_json_error(e):
  52 + print("Invalid JSON")
  53 + print(e.doc)
  54 + print(" " * e.pos + "^")
  55 + print(e.msg)
  56 + print("")
  57 +
  58 +async def run():
  59 + while True:
  60 + if not connected:
  61 + await asyncio.sleep(1)
  62 + continue
  63 +
  64 + print("Enter topic: ")
  65 + topic = await aioconsole.ainput()
  66 + print("Enter data as JSON: ")
  67 + data = await aioconsole.ainput()
  68 +
  69 + try:
  70 + decoded = json.loads(data)
  71 + except json.decoder.JSONDecodeError as e:
  72 + print_json_error(e)
  73 + continue
  74 +
  75 + await sio.emit(
  76 + event="to_neuronavigation",
  77 + data={
  78 + "topic": topic,
  79 + "data": decoded,
  80 + }
  81 + )
  82 +
  83 +async def main():
  84 + asyncio.create_task(run())
  85 + uvicorn.run(app, port=port, host='0.0.0.0', loop='asyncio')
  86 +
  87 +if __name__ == '__main__':
  88 + asyncio.run(main(), debug=True)