Commit b6ae340436ae6e2556de9f1521b7ffc029cac8a6
Committed by
GitHub
Exists in
master
Merge branch 'master' into multimodal_tracking
Showing
14 changed files
with
489 additions
and
57 deletions
Show diff stats
invesalius/constants.py
... | ... | @@ -830,8 +830,14 @@ TREKKER_CONFIG = {'seed_max': 1, 'step_size': 0.1, 'min_fod': 0.1, 'probe_qualit |
830 | 830 | |
831 | 831 | MARKER_FILE_MAGICK_STRING = "INVESALIUS3_MARKER_FILE_" |
832 | 832 | CURRENT_MARKER_FILE_VERSION = 0 |
833 | + | |
833 | 834 | WILDCARD_MARKER_FILES = _("Marker scanner coord files (*.mkss)|*.mkss") |
834 | 835 | |
836 | +# Serial port | |
837 | +BAUD_RATES = [300, 1200, 2400, 4800, 9600, 19200, 38400, 57600, 115200] | |
838 | +BAUD_RATE_DEFAULT_SELECTION = 4 | |
839 | + | |
840 | +#Robot | |
835 | 841 | ROBOT_ElFIN_IP = ['Select robot IP:', '143.107.220.251', '169.254.153.251', '127.0.0.1'] |
836 | 842 | ROBOT_ElFIN_PORT = 10003 |
837 | 843 | ... | ... |
invesalius/data/serial_port_connection.py
... | ... | @@ -28,7 +28,7 @@ from invesalius.pubsub import pub as Publisher |
28 | 28 | class SerialPortConnection(threading.Thread): |
29 | 29 | BINARY_PULSE = b'\x01' |
30 | 30 | |
31 | - def __init__(self, port, serial_port_queue, event, sleep_nav): | |
31 | + def __init__(self, com_port, baud_rate, serial_port_queue, event, sleep_nav): | |
32 | 32 | """ |
33 | 33 | Thread created to communicate using the serial port to interact with software during neuronavigation. |
34 | 34 | """ |
... | ... | @@ -37,28 +37,29 @@ class SerialPortConnection(threading.Thread): |
37 | 37 | self.connection = None |
38 | 38 | self.stylusplh = False |
39 | 39 | |
40 | - self.port = port | |
40 | + self.com_port = com_port | |
41 | + self.baud_rate = baud_rate | |
41 | 42 | self.serial_port_queue = serial_port_queue |
42 | 43 | self.event = event |
43 | 44 | self.sleep_nav = sleep_nav |
44 | 45 | |
45 | 46 | def Connect(self): |
46 | - if self.port is None: | |
47 | + if self.com_port is None: | |
47 | 48 | print("Serial port init error: COM port is unset.") |
48 | 49 | return |
49 | 50 | try: |
50 | 51 | import serial |
51 | - self.connection = serial.Serial(self.port, baudrate=115200, timeout=0) | |
52 | - print("Connection to port {} opened.".format(self.port)) | |
52 | + self.connection = serial.Serial(self.com_port, baudrate=self.baud_rate, timeout=0) | |
53 | + print("Connection to port {} opened.".format(self.com_port)) | |
53 | 54 | |
54 | 55 | Publisher.sendMessage('Serial port connection', state=True) |
55 | 56 | except: |
56 | - print("Serial port init error: Connecting to port {} failed.".format(self.port)) | |
57 | + print("Serial port init error: Connecting to port {} failed.".format(self.com_port)) | |
57 | 58 | |
58 | 59 | def Disconnect(self): |
59 | 60 | if self.connection: |
60 | 61 | self.connection.close() |
61 | - print("Connection to port {} closed.".format(self.port)) | |
62 | + print("Connection to port {} closed.".format(self.com_port)) | |
62 | 63 | |
63 | 64 | Publisher.sendMessage('Serial port connection', state=False) |
64 | 65 | |
... | ... | @@ -74,12 +75,11 @@ class SerialPortConnection(threading.Thread): |
74 | 75 | trigger_on = False |
75 | 76 | try: |
76 | 77 | lines = self.connection.readlines() |
78 | + if lines: | |
79 | + trigger_on = True | |
77 | 80 | except: |
78 | 81 | print("Error: Serial port could not be read.") |
79 | 82 | |
80 | - if lines: | |
81 | - trigger_on = True | |
82 | - | |
83 | 83 | if self.stylusplh: |
84 | 84 | trigger_on = True |
85 | 85 | self.stylusplh = False | ... | ... |
invesalius/data/trackers.py
... | ... | @@ -299,7 +299,7 @@ def PlhSerialConnection(tracker_id): |
299 | 299 | import serial |
300 | 300 | from wx import ID_OK |
301 | 301 | trck_init = None |
302 | - dlg_port = dlg.SetCOMport() | |
302 | + dlg_port = dlg.SetCOMPort(select_baud_rate=False) | |
303 | 303 | if dlg_port.ShowModal() == ID_OK: |
304 | 304 | com_port = dlg_port.GetValue() |
305 | 305 | try: | ... | ... |
invesalius/gui/brain_seg_dialog.py
... | ... | @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) |
22 | 22 | HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) |
23 | 23 | PLAIDML_DEVICES = {} |
24 | 24 | |
25 | +try: | |
26 | + import torch | |
27 | + HAS_TORCH = True | |
28 | +except ImportError: | |
29 | + HAS_TORCH = False | |
30 | + | |
31 | +if HAS_TORCH: | |
32 | + TORCH_DEVICES = {} | |
33 | + if torch.cuda.is_available(): | |
34 | + for i in range(torch.cuda.device_count()): | |
35 | + name = torch.cuda.get_device_name() | |
36 | + device_id = f'cuda:{i}' | |
37 | + TORCH_DEVICES[name] = device_id | |
38 | + TORCH_DEVICES['CPU'] = 'cpu' | |
39 | + | |
40 | + | |
25 | 41 | |
26 | 42 | if HAS_PLAIDML: |
27 | 43 | with multiprocessing.Pool(1) as p: |
... | ... | @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): |
43 | 59 | style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, |
44 | 60 | ) |
45 | 61 | backends = [] |
62 | + if HAS_TORCH: | |
63 | + backends.append("Pytorch") | |
46 | 64 | if HAS_PLAIDML: |
47 | 65 | backends.append("PlaidML") |
48 | 66 | if HAS_THEANO: |
49 | 67 | backends.append("Theano") |
50 | 68 | # self.segmenter = segment.BrainSegmenter() |
51 | 69 | # self.pg_dialog = None |
70 | + self.torch_devices = TORCH_DEVICES | |
52 | 71 | self.plaidml_devices = PLAIDML_DEVICES |
53 | 72 | |
54 | 73 | self.ps = None |
... | ... | @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): |
65 | 84 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
66 | 85 | self.cb_backends.SetMinClientSize((w, -1)) |
67 | 86 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
68 | - if HAS_PLAIDML: | |
87 | + if HAS_TORCH or HAS_PLAIDML: | |
88 | + if HAS_TORCH: | |
89 | + choices = list(self.torch_devices.keys()) | |
90 | + value = choices[0] | |
91 | + else: | |
92 | + choices = list(self.plaidml_devices.keys()) | |
93 | + value = choices[0] | |
69 | 94 | self.lbl_device = wx.StaticText(self, -1, _("Device")) |
70 | 95 | self.cb_devices = wx.ComboBox( |
71 | 96 | self, |
72 | 97 | wx.ID_ANY, |
73 | - choices=list(self.plaidml_devices.keys()), | |
74 | - value=list(self.plaidml_devices.keys())[0], | |
98 | + choices=choices, | |
99 | + value=value, | |
75 | 100 | style=wx.CB_DROPDOWN | wx.CB_READONLY, |
76 | 101 | ) |
77 | 102 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
... | ... | @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): |
109 | 134 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
110 | 135 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
111 | 136 | sizer_devices = wx.BoxSizer(wx.HORIZONTAL) |
112 | - if HAS_PLAIDML: | |
137 | + if HAS_TORCH or HAS_PLAIDML: | |
113 | 138 | sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) |
114 | 139 | sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) |
115 | 140 | main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) |
... | ... | @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): |
177 | 202 | return width, height |
178 | 203 | |
179 | 204 | def OnSetBackend(self, evt=None): |
180 | - if self.cb_backends.GetValue().lower() == "plaidml": | |
205 | + if self.cb_backends.GetValue().lower() == "pytorch": | |
206 | + if HAS_TORCH: | |
207 | + choices = list(self.torch_devices.keys()) | |
208 | + self.cb_devices.Clear() | |
209 | + self.cb_devices.SetItems(choices) | |
210 | + self.cb_devices.SetValue(choices[0]) | |
211 | + self.lbl_device.Show() | |
212 | + self.cb_devices.Show() | |
213 | + self.chk_use_gpu.Hide() | |
214 | + elif self.cb_backends.GetValue().lower() == "plaidml": | |
181 | 215 | if HAS_PLAIDML: |
216 | + choices = list(self.plaidml_devices.keys()) | |
217 | + self.cb_devices.Clear() | |
218 | + self.cb_devices.SetItems(choices) | |
219 | + self.cb_devices.SetValue(choices[0]) | |
182 | 220 | self.lbl_device.Show() |
183 | 221 | self.cb_devices.Show() |
184 | 222 | self.chk_use_gpu.Hide() |
... | ... | @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): |
216 | 254 | self.elapsed_time_timer.Start(1000) |
217 | 255 | image = slc.Slice().matrix |
218 | 256 | backend = self.cb_backends.GetValue() |
219 | - try: | |
220 | - device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
221 | - except (KeyError, AttributeError): | |
222 | - device_id = "llvm_cpu.0" | |
257 | + if backend.lower() == "pytorch": | |
258 | + try: | |
259 | + device_id = self.torch_devices[self.cb_devices.GetValue()] | |
260 | + except (KeyError, AttributeError): | |
261 | + device_id = "cpu" | |
262 | + else: | |
263 | + try: | |
264 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
265 | + except (KeyError, AttributeError): | |
266 | + device_id = "llvm_cpu.0" | |
223 | 267 | apply_wwwl = self.chk_apply_wwwl.GetValue() |
224 | 268 | create_new_mask = self.chk_new_mask.GetValue() |
225 | 269 | use_gpu = self.chk_use_gpu.GetValue() | ... | ... |
invesalius/gui/dialogs.py
... | ... | @@ -4632,7 +4632,8 @@ class SetNDIconfigs(wx.Dialog): |
4632 | 4632 | self._init_gui() |
4633 | 4633 | |
4634 | 4634 | def serial_ports(self): |
4635 | - """ Lists serial port names and pre-select the description containing NDI | |
4635 | + """ | |
4636 | + Lists serial port names and pre-select the description containing NDI | |
4636 | 4637 | """ |
4637 | 4638 | import serial.tools.list_ports |
4638 | 4639 | |
... | ... | @@ -4748,13 +4749,16 @@ class SetNDIconfigs(wx.Dialog): |
4748 | 4749 | return self.com_ports.GetString(self.com_ports.GetSelection()).encode(const.FS_ENCODE), fn_probe, fn_ref, fn_obj |
4749 | 4750 | |
4750 | 4751 | |
4751 | -class SetCOMport(wx.Dialog): | |
4752 | - def __init__(self, title=_("Select COM port")): | |
4753 | - wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE|wx.FRAME_FLOAT_ON_PARENT|wx.STAY_ON_TOP) | |
4752 | +class SetCOMPort(wx.Dialog): | |
4753 | + def __init__(self, select_baud_rate, title=_("Select COM port")): | |
4754 | + wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT | wx.STAY_ON_TOP) | |
4755 | + | |
4756 | + self.select_baud_rate = select_baud_rate | |
4754 | 4757 | self._init_gui() |
4755 | 4758 | |
4756 | 4759 | def serial_ports(self): |
4757 | - """ Lists serial port names | |
4760 | + """ | |
4761 | + Lists serial port names | |
4758 | 4762 | """ |
4759 | 4763 | import serial.tools.list_ports |
4760 | 4764 | if sys.platform.startswith('win'): |
... | ... | @@ -4764,12 +4768,26 @@ class SetCOMport(wx.Dialog): |
4764 | 4768 | return ports |
4765 | 4769 | |
4766 | 4770 | def _init_gui(self): |
4767 | - self.com_ports = wx.ComboBox(self, -1, style=wx.CB_DROPDOWN|wx.CB_READONLY) | |
4771 | + # COM port selection | |
4768 | 4772 | ports = self.serial_ports() |
4769 | - self.com_ports.Append(ports) | |
4773 | + self.com_port_dropdown = wx.ComboBox(self, -1, choices=ports, style=wx.CB_DROPDOWN | wx.CB_READONLY) | |
4774 | + self.com_port_dropdown.SetSelection(0) | |
4775 | + | |
4776 | + com_port_text_and_dropdown = wx.BoxSizer(wx.VERTICAL) | |
4777 | + com_port_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "COM port"), 0, wx.TOP | wx.RIGHT,5) | |
4778 | + com_port_text_and_dropdown.Add(self.com_port_dropdown, 0, wx.EXPAND) | |
4779 | + | |
4780 | + # Baud rate selection | |
4781 | + if self.select_baud_rate: | |
4782 | + baud_rates_as_strings = [str(baud_rate) for baud_rate in const.BAUD_RATES] | |
4783 | + self.baud_rate_dropdown = wx.ComboBox(self, -1, choices=baud_rates_as_strings, style=wx.CB_DROPDOWN | wx.CB_READONLY) | |
4784 | + self.baud_rate_dropdown.SetSelection(const.BAUD_RATE_DEFAULT_SELECTION) | |
4770 | 4785 | |
4771 | - # self.goto_orientation.SetSelection(cb_init) | |
4786 | + baud_rate_text_and_dropdown = wx.BoxSizer(wx.VERTICAL) | |
4787 | + baud_rate_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "Baud rate"), 0, wx.TOP | wx.RIGHT,5) | |
4788 | + baud_rate_text_and_dropdown.Add(self.baud_rate_dropdown, 0, wx.EXPAND) | |
4772 | 4789 | |
4790 | + # OK and Cancel buttons | |
4773 | 4791 | btn_ok = wx.Button(self, wx.ID_OK) |
4774 | 4792 | btn_ok.SetHelpText("") |
4775 | 4793 | btn_ok.SetDefault() |
... | ... | @@ -4782,10 +4800,16 @@ class SetCOMport(wx.Dialog): |
4782 | 4800 | btnsizer.AddButton(btn_cancel) |
4783 | 4801 | btnsizer.Realize() |
4784 | 4802 | |
4803 | + # Set up the main sizer | |
4785 | 4804 | main_sizer = wx.BoxSizer(wx.VERTICAL) |
4786 | 4805 | |
4787 | 4806 | main_sizer.Add((5, 5)) |
4788 | - main_sizer.Add(self.com_ports, 1, wx.EXPAND|wx.LEFT|wx.RIGHT, 5) | |
4807 | + main_sizer.Add(com_port_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5) | |
4808 | + | |
4809 | + if self.select_baud_rate: | |
4810 | + main_sizer.Add((5, 5)) | |
4811 | + main_sizer.Add(baud_rate_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5) | |
4812 | + | |
4789 | 4813 | main_sizer.Add((5, 5)) |
4790 | 4814 | main_sizer.Add(btnsizer, 0, wx.EXPAND) |
4791 | 4815 | main_sizer.Add((5, 5)) |
... | ... | @@ -4796,7 +4820,14 @@ class SetCOMport(wx.Dialog): |
4796 | 4820 | self.CenterOnParent() |
4797 | 4821 | |
4798 | 4822 | def GetValue(self): |
4799 | - return self.com_ports.GetString(self.com_ports.GetSelection()) | |
4823 | + com_port = self.com_port_dropdown.GetString(self.com_port_dropdown.GetSelection()) | |
4824 | + | |
4825 | + if self.select_baud_rate: | |
4826 | + baud_rate = self.baud_rate_dropdown.GetString(self.baud_rate_dropdown.GetSelection()) | |
4827 | + else: | |
4828 | + baud_rate = None | |
4829 | + | |
4830 | + return com_port, baud_rate | |
4800 | 4831 | |
4801 | 4832 | |
4802 | 4833 | class ManualWWWLDialog(wx.Dialog): | ... | ... |
invesalius/gui/task_navigator.py
... | ... | @@ -234,8 +234,8 @@ class InnerFoldPanel(wx.Panel): |
234 | 234 | checkcamera.Bind(wx.EVT_CHECKBOX, self.OnVolumeCamera) |
235 | 235 | self.checkcamera = checkcamera |
236 | 236 | |
237 | - # Check box to create markers from serial port | |
238 | - tooltip = wx.ToolTip(_("Enable serial port communication for creating markers")) | |
237 | + # Check box to use serial port to trigger pulse signal and create markers | |
238 | + tooltip = wx.ToolTip(_("Enable serial port communication to trigger pulse and create markers")) | |
239 | 239 | checkbox_serial_port = wx.CheckBox(self, -1, _('Serial port')) |
240 | 240 | checkbox_serial_port.SetToolTip(tooltip) |
241 | 241 | checkbox_serial_port.SetValue(False) |
... | ... | @@ -297,14 +297,20 @@ class InnerFoldPanel(wx.Panel): |
297 | 297 | self.checkobj.Enable(True) |
298 | 298 | |
299 | 299 | def OnEnableSerialPort(self, evt, ctrl): |
300 | - com_port = None | |
301 | 300 | if ctrl.GetValue(): |
302 | 301 | from wx import ID_OK |
303 | - dlg_port = dlg.SetCOMport() | |
304 | - if dlg_port.ShowModal() == ID_OK: | |
305 | - com_port = dlg_port.GetValue() | |
302 | + dlg_port = dlg.SetCOMPort(select_baud_rate=False) | |
306 | 303 | |
307 | - Publisher.sendMessage('Update serial port', serial_port=com_port) | |
304 | + if dlg_port.ShowModal() != ID_OK: | |
305 | + ctrl.SetValue(False) | |
306 | + return | |
307 | + | |
308 | + com_port = dlg_port.GetValue() | |
309 | + baud_rate = 115200 | |
310 | + | |
311 | + Publisher.sendMessage('Update serial port', serial_port_in_use=True, com_port=com_port, baud_rate=baud_rate) | |
312 | + else: | |
313 | + Publisher.sendMessage('Update serial port', serial_port_in_use=False) | |
308 | 314 | |
309 | 315 | def OnShowObject(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): |
310 | 316 | if not evt: |
... | ... | @@ -503,7 +509,6 @@ class NeuronavigationPanel(wx.Panel): |
503 | 509 | Publisher.subscribe(self.LoadImageFiducials, 'Load image fiducials') |
504 | 510 | Publisher.subscribe(self.SetImageFiducial, 'Set image fiducial') |
505 | 511 | Publisher.subscribe(self.SetTrackerFiducial, 'Set tracker fiducial') |
506 | - Publisher.subscribe(self.UpdateSerialPort, 'Update serial port') | |
507 | 512 | Publisher.subscribe(self.UpdateTrackObjectState, 'Update track object state') |
508 | 513 | Publisher.subscribe(self.UpdateImageCoordinates, 'Set cross focal point') |
509 | 514 | Publisher.subscribe(self.OnDisconnectTracker, 'Disconnect tracker') |
... | ... | @@ -627,9 +632,6 @@ class NeuronavigationPanel(wx.Panel): |
627 | 632 | def UpdateTrackObjectState(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): |
628 | 633 | self.navigation.track_obj = flag |
629 | 634 | |
630 | - def UpdateSerialPort(self, serial_port): | |
631 | - self.navigation.serial_port = serial_port | |
632 | - | |
633 | 635 | def ResetICP(self): |
634 | 636 | self.icp.ResetICP() |
635 | 637 | self.checkbox_icp.Enable(False) | ... | ... |
invesalius/inv_paths.py
... | ... | @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co |
27 | 27 | USER_INV_DIR = CONF_DIR.joinpath("invesalius") |
28 | 28 | USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") |
29 | 29 | USER_LOG_DIR = USER_INV_DIR.joinpath("logs") |
30 | +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/") | |
30 | 31 | USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") |
31 | 32 | TEMP_DIR = tempfile.gettempdir() |
32 | 33 | |
... | ... | @@ -97,6 +98,7 @@ def create_conf_folders(): |
97 | 98 | USER_INV_DIR.mkdir(parents=True, exist_ok=True) |
98 | 99 | USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) |
99 | 100 | USER_LOG_DIR.mkdir(parents=True, exist_ok=True) |
101 | + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True) | |
100 | 102 | USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) |
101 | 103 | |
102 | 104 | ... | ... |
invesalius/navigation/navigation.py
... | ... | @@ -171,7 +171,9 @@ class Navigation(): |
171 | 171 | self.sleep_nav = const.SLEEP_NAVIGATION |
172 | 172 | |
173 | 173 | # Serial port |
174 | - self.serial_port = None | |
174 | + self.serial_port_in_use = False | |
175 | + self.com_port = None | |
176 | + self.baud_rate = None | |
175 | 177 | self.serial_port_connection = None |
176 | 178 | |
177 | 179 | # During navigation |
... | ... | @@ -181,6 +183,7 @@ class Navigation(): |
181 | 183 | |
182 | 184 | def __bind_events(self): |
183 | 185 | Publisher.subscribe(self.CoilAtTarget, 'Coil at target') |
186 | + Publisher.subscribe(self.UpdateSerialPort, 'Update serial port') | |
184 | 187 | |
185 | 188 | def CoilAtTarget(self, state): |
186 | 189 | self.coil_at_target = state |
... | ... | @@ -189,8 +192,10 @@ class Navigation(): |
189 | 192 | self.sleep_nav = sleep |
190 | 193 | self.serial_port_connection.sleep_nav = sleep |
191 | 194 | |
192 | - def SerialPortEnabled(self): | |
193 | - return self.serial_port is not None | |
195 | + def UpdateSerialPort(self, serial_port_in_use, com_port=None, baud_rate=None): | |
196 | + self.serial_port_in_use = serial_port_in_use | |
197 | + self.com_port = com_port | |
198 | + self.baud_rate = baud_rate | |
194 | 199 | |
195 | 200 | def SetReferenceMode(self, value): |
196 | 201 | self.ref_mode_id = value |
... | ... | @@ -218,7 +223,7 @@ class Navigation(): |
218 | 223 | return fre, fre <= const.FIDUCIAL_REGISTRATION_ERROR_THRESHOLD |
219 | 224 | |
220 | 225 | def PedalStateChanged(self, state): |
221 | - if state is True and self.coil_at_target and self.SerialPortEnabled(): | |
226 | + if state is True and self.coil_at_target and self.serial_port_in_use: | |
222 | 227 | self.serial_port_connection.SendPulse() |
223 | 228 | |
224 | 229 | def StartNavigation(self, tracker): |
... | ... | @@ -230,7 +235,7 @@ class Navigation(): |
230 | 235 | if self.event.is_set(): |
231 | 236 | self.event.clear() |
232 | 237 | |
233 | - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded] | |
238 | + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded] | |
234 | 239 | vis_queues = [self.coord_queue, self.serial_port_queue, self.tracts_queue, self.icp_queue, self.robottarget_queue] |
235 | 240 | |
236 | 241 | Publisher.sendMessage("Navigation status", nav_status=True, vis_status=vis_components) |
... | ... | @@ -279,12 +284,13 @@ class Navigation(): |
279 | 284 | |
280 | 285 | if not errors: |
281 | 286 | #TODO: Test the serial port thread |
282 | - if self.SerialPortEnabled(): | |
287 | + if self.serial_port_in_use: | |
283 | 288 | self.serial_port_connection = spc.SerialPortConnection( |
284 | - self.serial_port, | |
285 | - self.serial_port_queue, | |
286 | - self.event, | |
287 | - self.sleep_nav, | |
289 | + com_port=self.com_port, | |
290 | + baud_rate=self.baud_rate, | |
291 | + serial_port_queue=self.serial_port_queue, | |
292 | + event=self.event, | |
293 | + sleep_nav=self.sleep_nav, | |
288 | 294 | ) |
289 | 295 | self.serial_port_connection.Connect() |
290 | 296 | jobs_list.append(self.serial_port_connection) |
... | ... | @@ -330,7 +336,7 @@ class Navigation(): |
330 | 336 | if self.serial_port_connection is not None: |
331 | 337 | self.serial_port_connection.join() |
332 | 338 | |
333 | - if self.SerialPortEnabled(): | |
339 | + if self.serial_port_in_use: | |
334 | 340 | self.serial_port_queue.clear() |
335 | 341 | self.serial_port_queue.join() |
336 | 342 | |
... | ... | @@ -341,5 +347,5 @@ class Navigation(): |
341 | 347 | self.tracts_queue.clear() |
342 | 348 | self.tracts_queue.join() |
343 | 349 | |
344 | - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded] | |
350 | + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded] | |
345 | 351 | Publisher.sendMessage("Navigation status", nav_status=False, vis_status=vis_components) | ... | ... |
... | ... | @@ -0,0 +1,48 @@ |
1 | +from urllib.error import HTTPError | |
2 | +from urllib.request import urlopen, Request | |
3 | +from urllib.parse import urlparse | |
4 | +import pathlib | |
5 | +import tempfile | |
6 | +import typing | |
7 | +import hashlib | |
8 | +import os | |
9 | +import shutil | |
10 | + | |
11 | +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None): | |
12 | + file_size = None | |
13 | + total_downloaded = 0 | |
14 | + if hash is not None: | |
15 | + calc_hash = hashlib.sha256() | |
16 | + req = Request(url) | |
17 | + response = urlopen(req) | |
18 | + meta = response.info() | |
19 | + if hasattr(meta, "getheaders"): | |
20 | + content_length = meta.getheaders("Content-Length") | |
21 | + else: | |
22 | + content_length = meta.get_all("Content-Length") | |
23 | + | |
24 | + if content_length is not None and len(content_length) > 0: | |
25 | + file_size = int(content_length[0]) | |
26 | + dst.parent.mkdir(parents=True, exist_ok=True) | |
27 | + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent) | |
28 | + try: | |
29 | + while True: | |
30 | + buffer = response.read(8192) | |
31 | + if len(buffer) == 0: | |
32 | + break | |
33 | + total_downloaded += len(buffer) | |
34 | + f.write(buffer) | |
35 | + if hash: | |
36 | + calc_hash.update(buffer) | |
37 | + if callback is not None: | |
38 | + callback(100 * total_downloaded/file_size) | |
39 | + f.close() | |
40 | + if hash is not None: | |
41 | + digest = calc_hash.hexdigest() | |
42 | + if digest != hash: | |
43 | + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")') | |
44 | + shutil.move(f.name, dst) | |
45 | + finally: | |
46 | + f.close() | |
47 | + if os.path.exists(f.name): | |
48 | + os.remove(f.name) | ... | ... |
... | ... | @@ -0,0 +1,149 @@ |
1 | +from collections import OrderedDict | |
2 | + | |
3 | +import torch | |
4 | +import torch.nn as nn | |
5 | + | |
6 | +SIZE = 48 | |
7 | + | |
8 | +class Unet3D(nn.Module): | |
9 | + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py | |
10 | + def __init__(self, in_channels=1, out_channels=1, init_features=8): | |
11 | + super().__init__() | |
12 | + features = init_features | |
13 | + | |
14 | + self.encoder1 = self._block( | |
15 | + in_channels, features=features, padding=2, name="enc1" | |
16 | + ) | |
17 | + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) | |
18 | + | |
19 | + self.encoder2 = self._block( | |
20 | + features, features=features * 2, padding=2, name="enc2" | |
21 | + ) | |
22 | + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) | |
23 | + | |
24 | + self.encoder3 = self._block( | |
25 | + features * 2, features=features * 4, padding=2, name="enc3" | |
26 | + ) | |
27 | + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) | |
28 | + | |
29 | + self.encoder4 = self._block( | |
30 | + features * 4, features=features * 8, padding=2, name="enc4" | |
31 | + ) | |
32 | + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) | |
33 | + | |
34 | + self.bottleneck = self._block( | |
35 | + features * 8, features=features * 16, padding=2, name="bottleneck" | |
36 | + ) | |
37 | + | |
38 | + self.upconv4 = nn.ConvTranspose3d( | |
39 | + features * 16, features * 8, kernel_size=4, stride=2, padding=1 | |
40 | + ) | |
41 | + self.decoder4 = self._block( | |
42 | + features * 16, features=features * 8, padding=2, name="dec4" | |
43 | + ) | |
44 | + | |
45 | + self.upconv3 = nn.ConvTranspose3d( | |
46 | + features * 8, features * 4, kernel_size=4, stride=2, padding=1 | |
47 | + ) | |
48 | + self.decoder3 = self._block( | |
49 | + features * 8, features=features * 4, padding=2, name="dec4" | |
50 | + ) | |
51 | + | |
52 | + self.upconv2 = nn.ConvTranspose3d( | |
53 | + features * 4, features * 2, kernel_size=4, stride=2, padding=1 | |
54 | + ) | |
55 | + self.decoder2 = self._block( | |
56 | + features * 4, features=features * 2, padding=2, name="dec4" | |
57 | + ) | |
58 | + | |
59 | + self.upconv1 = nn.ConvTranspose3d( | |
60 | + features * 2, features, kernel_size=4, stride=2, padding=1 | |
61 | + ) | |
62 | + self.decoder1 = self._block( | |
63 | + features * 2, features=features, padding=2, name="dec4" | |
64 | + ) | |
65 | + | |
66 | + self.conv = nn.Conv3d( | |
67 | + in_channels=features, out_channels=out_channels, kernel_size=1 | |
68 | + ) | |
69 | + | |
70 | + def forward(self, img): | |
71 | + enc1 = self.encoder1(img) | |
72 | + enc2 = self.encoder2(self.pool1(enc1)) | |
73 | + enc3 = self.encoder3(self.pool2(enc2)) | |
74 | + enc4 = self.encoder4(self.pool3(enc3)) | |
75 | + | |
76 | + bottleneck = self.bottleneck(self.pool4(enc4)) | |
77 | + | |
78 | + upconv4 = self.upconv4(bottleneck) | |
79 | + dec4 = torch.cat((upconv4, enc4), dim=1) | |
80 | + dec4 = self.decoder4(dec4) | |
81 | + | |
82 | + upconv3 = self.upconv3(dec4) | |
83 | + dec3 = torch.cat((upconv3, enc3), dim=1) | |
84 | + dec3 = self.decoder3(dec3) | |
85 | + | |
86 | + upconv2 = self.upconv2(dec3) | |
87 | + dec2 = torch.cat((upconv2, enc2), dim=1) | |
88 | + dec2 = self.decoder2(dec2) | |
89 | + | |
90 | + upconv1 = self.upconv1(dec2) | |
91 | + dec1 = torch.cat((upconv1, enc1), dim=1) | |
92 | + dec1 = self.decoder1(dec1) | |
93 | + | |
94 | + conv = self.conv(dec1) | |
95 | + | |
96 | + sigmoid = torch.sigmoid(conv) | |
97 | + | |
98 | + return sigmoid | |
99 | + | |
100 | + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"): | |
101 | + return nn.Sequential( | |
102 | + OrderedDict( | |
103 | + ( | |
104 | + ( | |
105 | + f"{name}_conv1", | |
106 | + nn.Conv3d( | |
107 | + in_channels=in_channels, | |
108 | + out_channels=features, | |
109 | + kernel_size=kernel_size, | |
110 | + padding=padding, | |
111 | + bias=True, | |
112 | + ), | |
113 | + ), | |
114 | + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)), | |
115 | + (f"{name}_relu1", nn.ReLU(inplace=True)), | |
116 | + ( | |
117 | + f"{name}_conv2", | |
118 | + nn.Conv3d( | |
119 | + in_channels=features, | |
120 | + out_channels=features, | |
121 | + kernel_size=kernel_size, | |
122 | + padding=padding, | |
123 | + bias=True, | |
124 | + ), | |
125 | + ), | |
126 | + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)), | |
127 | + (f"{name}_relu2", nn.ReLU(inplace=True)), | |
128 | + ) | |
129 | + ) | |
130 | + ) | |
131 | + | |
132 | + | |
133 | +def main(): | |
134 | + import torchviz | |
135 | + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
136 | + model = Unet3D() | |
137 | + model.to(dev) | |
138 | + model.eval() | |
139 | + print(next(model.parameters()).is_cuda) # True | |
140 | + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev) | |
141 | + out = model(img) | |
142 | + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) | |
143 | + dot.render("unet", format="png") | |
144 | + torch.save(model, "model.pth") | |
145 | + print(dot) | |
146 | + | |
147 | + | |
148 | +if __name__ == "__main__": | |
149 | + main() | ... | ... |
invesalius/segmentation/brain/segment.py
... | ... | @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc |
13 | 13 | from invesalius import inv_paths |
14 | 14 | from invesalius.data import imagedata_utils |
15 | 15 | from invesalius.utils import new_name_by_pattern |
16 | +from invesalius.net.utils import download_url_to_file | |
17 | +from invesalius import inv_paths | |
16 | 18 | |
17 | 19 | from . import utils |
18 | 20 | |
... | ... | @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): |
64 | 66 | 0 : ez - iz, 0 : ey - iy, 0 : ex - ix |
65 | 67 | ] |
66 | 68 | |
69 | +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE): | |
70 | + import torch | |
71 | + with torch.no_grad(): | |
72 | + (iz, ez), (iy, ey), (ix, ex) = patch | |
73 | + sub_mask = nn_model( | |
74 | + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device) | |
75 | + ).cpu().numpy() | |
76 | + return sub_mask.reshape(patch_size, patch_size, patch_size)[ | |
77 | + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix | |
78 | + ] | |
79 | + | |
67 | 80 | |
68 | 81 | def brain_segment(image, probability_array, comm_array): |
69 | 82 | import keras |
... | ... | @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): |
89 | 102 | comm_array[0] = np.Inf |
90 | 103 | |
91 | 104 | |
105 | +def download_callback(comm_array): | |
106 | + def _download_callback(value): | |
107 | + comm_array[0] = value | |
108 | + return _download_callback | |
109 | + | |
110 | +def brain_segment_torch(image, device_id, probability_array, comm_array): | |
111 | + import torch | |
112 | + from .model import Unet3D | |
113 | + device = torch.device(device_id) | |
114 | + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | |
115 | + if not state_dict_file.exists(): | |
116 | + download_url_to_file( | |
117 | + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", | |
118 | + state_dict_file, | |
119 | + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", | |
120 | + download_callback(comm_array) | |
121 | + ) | |
122 | + state_dict = torch.load(str(state_dict_file)) | |
123 | + model = Unet3D() | |
124 | + model.load_state_dict(state_dict["model_state_dict"]) | |
125 | + model.to(device) | |
126 | + model.eval() | |
127 | + | |
128 | + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32) | |
129 | + sums = np.zeros_like(image) | |
130 | + # segmenting by patches | |
131 | + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP): | |
132 | + comm_array[0] = completion | |
133 | + (iz, ez), (iy, ey), (ix, ex) = patch | |
134 | + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE) | |
135 | + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask | |
136 | + sums[iz:ez, iy:ey, ix:ex] += 1 | |
137 | + | |
138 | + probability_array /= sums | |
139 | + comm_array[0] = np.Inf | |
140 | + | |
92 | 141 | ctx = multiprocessing.get_context('spawn') |
93 | 142 | class SegmentProcess(ctx.Process): |
94 | 143 | def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): |
... | ... | @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): |
138 | 187 | mode="r", |
139 | 188 | ) |
140 | 189 | |
141 | - print(image.min(), image.max()) | |
142 | - if self.apply_segment_threshold: | |
190 | + if self.apply_wwwl: | |
143 | 191 | print("Applying window level") |
144 | 192 | image = get_LUT_value(image, self.window_width, self.window_level) |
145 | 193 | |
... | ... | @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): |
153 | 201 | self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" |
154 | 202 | ) |
155 | 203 | |
156 | - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | |
157 | - brain_segment(image, probability_array, comm_array) | |
204 | + if self.backend.lower() == "pytorch": | |
205 | + brain_segment_torch(image, self.device_id, probability_array, comm_array) | |
206 | + else: | |
207 | + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | |
208 | + brain_segment(image, probability_array, comm_array) | |
158 | 209 | |
159 | 210 | @property |
160 | 211 | def exception(self): | ... | ... |
optional-requirements.txt
requirements.txt
... | ... | @@ -0,0 +1,88 @@ |
1 | +#!/usr/bin/env python3 | |
2 | +# -*- coding: utf-8 -*- | |
3 | + | |
4 | +# This scripts allows sending events to InVesalius via Socket.IO, mimicking InVesalius's | |
5 | +# internal communication. It can be useful for developing and debugging InVesalius. | |
6 | +# | |
7 | +# Example usage: | |
8 | +# | |
9 | +# - (In console window 1) Run the script by: python scripts/invesalius_server.py 5000 | |
10 | +# | |
11 | +# - (In console window 2) Run InVesalius by: python app.py --remote-host http://localhost:5000 | |
12 | +# | |
13 | +# - If InVesalius connected to the server successfully, a message should appear in console window 1, | |
14 | +# asking to provide the topic name. | |
15 | +# | |
16 | +# - Enter the topic name, such as "Add marker" (without quotes). | |
17 | +# | |
18 | +# - Enter the data, such as {"ball_id": 0, "size": 2, "colour": [1.0, 1.0, 0.0], "coord": [10.0, 20.0, 30.0]} | |
19 | +# | |
20 | +# - If successful, a message should now appear in console window 2, indicating that the event was received. | |
21 | + | |
22 | +import asyncio | |
23 | +import sys | |
24 | +import json | |
25 | + | |
26 | +import aioconsole | |
27 | +import nest_asyncio | |
28 | +import socketio | |
29 | +import uvicorn | |
30 | + | |
31 | +nest_asyncio.apply() | |
32 | + | |
33 | +if len(sys.argv) != 2: | |
34 | + print ("""This script allows sending events to InVesalius. | |
35 | + | |
36 | +Usage: python invesalius_server.py port""") | |
37 | + sys.exit(1) | |
38 | + | |
39 | +port = int(sys.argv[1]) | |
40 | + | |
41 | +sio = socketio.AsyncServer(async_mode='asgi') | |
42 | +app = socketio.ASGIApp(sio) | |
43 | + | |
44 | +connected = False | |
45 | + | |
46 | +@sio.event | |
47 | +def connect(sid, environ): | |
48 | + global connected | |
49 | + connected = True | |
50 | + | |
51 | +def print_json_error(e): | |
52 | + print("Invalid JSON") | |
53 | + print(e.doc) | |
54 | + print(" " * e.pos + "^") | |
55 | + print(e.msg) | |
56 | + print("") | |
57 | + | |
58 | +async def run(): | |
59 | + while True: | |
60 | + if not connected: | |
61 | + await asyncio.sleep(1) | |
62 | + continue | |
63 | + | |
64 | + print("Enter topic: ") | |
65 | + topic = await aioconsole.ainput() | |
66 | + print("Enter data as JSON: ") | |
67 | + data = await aioconsole.ainput() | |
68 | + | |
69 | + try: | |
70 | + decoded = json.loads(data) | |
71 | + except json.decoder.JSONDecodeError as e: | |
72 | + print_json_error(e) | |
73 | + continue | |
74 | + | |
75 | + await sio.emit( | |
76 | + event="to_neuronavigation", | |
77 | + data={ | |
78 | + "topic": topic, | |
79 | + "data": decoded, | |
80 | + } | |
81 | + ) | |
82 | + | |
83 | +async def main(): | |
84 | + asyncio.create_task(run()) | |
85 | + uvicorn.run(app, port=port, host='0.0.0.0', loop='asyncio') | |
86 | + | |
87 | +if __name__ == '__main__': | |
88 | + asyncio.run(main(), debug=True) | ... | ... |