Commit b6ae340436ae6e2556de9f1521b7ffc029cac8a6
Committed by
GitHub
Exists in
master
Merge branch 'master' into multimodal_tracking
Showing
14 changed files
with
489 additions
and
57 deletions
Show diff stats
invesalius/constants.py
@@ -830,8 +830,14 @@ TREKKER_CONFIG = {'seed_max': 1, 'step_size': 0.1, 'min_fod': 0.1, 'probe_qualit | @@ -830,8 +830,14 @@ TREKKER_CONFIG = {'seed_max': 1, 'step_size': 0.1, 'min_fod': 0.1, 'probe_qualit | ||
830 | 830 | ||
831 | MARKER_FILE_MAGICK_STRING = "INVESALIUS3_MARKER_FILE_" | 831 | MARKER_FILE_MAGICK_STRING = "INVESALIUS3_MARKER_FILE_" |
832 | CURRENT_MARKER_FILE_VERSION = 0 | 832 | CURRENT_MARKER_FILE_VERSION = 0 |
833 | + | ||
833 | WILDCARD_MARKER_FILES = _("Marker scanner coord files (*.mkss)|*.mkss") | 834 | WILDCARD_MARKER_FILES = _("Marker scanner coord files (*.mkss)|*.mkss") |
834 | 835 | ||
836 | +# Serial port | ||
837 | +BAUD_RATES = [300, 1200, 2400, 4800, 9600, 19200, 38400, 57600, 115200] | ||
838 | +BAUD_RATE_DEFAULT_SELECTION = 4 | ||
839 | + | ||
840 | +#Robot | ||
835 | ROBOT_ElFIN_IP = ['Select robot IP:', '143.107.220.251', '169.254.153.251', '127.0.0.1'] | 841 | ROBOT_ElFIN_IP = ['Select robot IP:', '143.107.220.251', '169.254.153.251', '127.0.0.1'] |
836 | ROBOT_ElFIN_PORT = 10003 | 842 | ROBOT_ElFIN_PORT = 10003 |
837 | 843 |
invesalius/data/serial_port_connection.py
@@ -28,7 +28,7 @@ from invesalius.pubsub import pub as Publisher | @@ -28,7 +28,7 @@ from invesalius.pubsub import pub as Publisher | ||
28 | class SerialPortConnection(threading.Thread): | 28 | class SerialPortConnection(threading.Thread): |
29 | BINARY_PULSE = b'\x01' | 29 | BINARY_PULSE = b'\x01' |
30 | 30 | ||
31 | - def __init__(self, port, serial_port_queue, event, sleep_nav): | 31 | + def __init__(self, com_port, baud_rate, serial_port_queue, event, sleep_nav): |
32 | """ | 32 | """ |
33 | Thread created to communicate using the serial port to interact with software during neuronavigation. | 33 | Thread created to communicate using the serial port to interact with software during neuronavigation. |
34 | """ | 34 | """ |
@@ -37,28 +37,29 @@ class SerialPortConnection(threading.Thread): | @@ -37,28 +37,29 @@ class SerialPortConnection(threading.Thread): | ||
37 | self.connection = None | 37 | self.connection = None |
38 | self.stylusplh = False | 38 | self.stylusplh = False |
39 | 39 | ||
40 | - self.port = port | 40 | + self.com_port = com_port |
41 | + self.baud_rate = baud_rate | ||
41 | self.serial_port_queue = serial_port_queue | 42 | self.serial_port_queue = serial_port_queue |
42 | self.event = event | 43 | self.event = event |
43 | self.sleep_nav = sleep_nav | 44 | self.sleep_nav = sleep_nav |
44 | 45 | ||
45 | def Connect(self): | 46 | def Connect(self): |
46 | - if self.port is None: | 47 | + if self.com_port is None: |
47 | print("Serial port init error: COM port is unset.") | 48 | print("Serial port init error: COM port is unset.") |
48 | return | 49 | return |
49 | try: | 50 | try: |
50 | import serial | 51 | import serial |
51 | - self.connection = serial.Serial(self.port, baudrate=115200, timeout=0) | ||
52 | - print("Connection to port {} opened.".format(self.port)) | 52 | + self.connection = serial.Serial(self.com_port, baudrate=self.baud_rate, timeout=0) |
53 | + print("Connection to port {} opened.".format(self.com_port)) | ||
53 | 54 | ||
54 | Publisher.sendMessage('Serial port connection', state=True) | 55 | Publisher.sendMessage('Serial port connection', state=True) |
55 | except: | 56 | except: |
56 | - print("Serial port init error: Connecting to port {} failed.".format(self.port)) | 57 | + print("Serial port init error: Connecting to port {} failed.".format(self.com_port)) |
57 | 58 | ||
58 | def Disconnect(self): | 59 | def Disconnect(self): |
59 | if self.connection: | 60 | if self.connection: |
60 | self.connection.close() | 61 | self.connection.close() |
61 | - print("Connection to port {} closed.".format(self.port)) | 62 | + print("Connection to port {} closed.".format(self.com_port)) |
62 | 63 | ||
63 | Publisher.sendMessage('Serial port connection', state=False) | 64 | Publisher.sendMessage('Serial port connection', state=False) |
64 | 65 | ||
@@ -74,12 +75,11 @@ class SerialPortConnection(threading.Thread): | @@ -74,12 +75,11 @@ class SerialPortConnection(threading.Thread): | ||
74 | trigger_on = False | 75 | trigger_on = False |
75 | try: | 76 | try: |
76 | lines = self.connection.readlines() | 77 | lines = self.connection.readlines() |
78 | + if lines: | ||
79 | + trigger_on = True | ||
77 | except: | 80 | except: |
78 | print("Error: Serial port could not be read.") | 81 | print("Error: Serial port could not be read.") |
79 | 82 | ||
80 | - if lines: | ||
81 | - trigger_on = True | ||
82 | - | ||
83 | if self.stylusplh: | 83 | if self.stylusplh: |
84 | trigger_on = True | 84 | trigger_on = True |
85 | self.stylusplh = False | 85 | self.stylusplh = False |
invesalius/data/trackers.py
@@ -299,7 +299,7 @@ def PlhSerialConnection(tracker_id): | @@ -299,7 +299,7 @@ def PlhSerialConnection(tracker_id): | ||
299 | import serial | 299 | import serial |
300 | from wx import ID_OK | 300 | from wx import ID_OK |
301 | trck_init = None | 301 | trck_init = None |
302 | - dlg_port = dlg.SetCOMport() | 302 | + dlg_port = dlg.SetCOMPort(select_baud_rate=False) |
303 | if dlg_port.ShowModal() == ID_OK: | 303 | if dlg_port.ShowModal() == ID_OK: |
304 | com_port = dlg_port.GetValue() | 304 | com_port = dlg_port.GetValue() |
305 | try: | 305 | try: |
invesalius/gui/brain_seg_dialog.py
@@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) | @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) | ||
22 | HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) | 22 | HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) |
23 | PLAIDML_DEVICES = {} | 23 | PLAIDML_DEVICES = {} |
24 | 24 | ||
25 | +try: | ||
26 | + import torch | ||
27 | + HAS_TORCH = True | ||
28 | +except ImportError: | ||
29 | + HAS_TORCH = False | ||
30 | + | ||
31 | +if HAS_TORCH: | ||
32 | + TORCH_DEVICES = {} | ||
33 | + if torch.cuda.is_available(): | ||
34 | + for i in range(torch.cuda.device_count()): | ||
35 | + name = torch.cuda.get_device_name() | ||
36 | + device_id = f'cuda:{i}' | ||
37 | + TORCH_DEVICES[name] = device_id | ||
38 | + TORCH_DEVICES['CPU'] = 'cpu' | ||
39 | + | ||
40 | + | ||
25 | 41 | ||
26 | if HAS_PLAIDML: | 42 | if HAS_PLAIDML: |
27 | with multiprocessing.Pool(1) as p: | 43 | with multiprocessing.Pool(1) as p: |
@@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): | ||
43 | style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, | 59 | style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, |
44 | ) | 60 | ) |
45 | backends = [] | 61 | backends = [] |
62 | + if HAS_TORCH: | ||
63 | + backends.append("Pytorch") | ||
46 | if HAS_PLAIDML: | 64 | if HAS_PLAIDML: |
47 | backends.append("PlaidML") | 65 | backends.append("PlaidML") |
48 | if HAS_THEANO: | 66 | if HAS_THEANO: |
49 | backends.append("Theano") | 67 | backends.append("Theano") |
50 | # self.segmenter = segment.BrainSegmenter() | 68 | # self.segmenter = segment.BrainSegmenter() |
51 | # self.pg_dialog = None | 69 | # self.pg_dialog = None |
70 | + self.torch_devices = TORCH_DEVICES | ||
52 | self.plaidml_devices = PLAIDML_DEVICES | 71 | self.plaidml_devices = PLAIDML_DEVICES |
53 | 72 | ||
54 | self.ps = None | 73 | self.ps = None |
@@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): | ||
65 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) | 84 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
66 | self.cb_backends.SetMinClientSize((w, -1)) | 85 | self.cb_backends.SetMinClientSize((w, -1)) |
67 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) | 86 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
68 | - if HAS_PLAIDML: | 87 | + if HAS_TORCH or HAS_PLAIDML: |
88 | + if HAS_TORCH: | ||
89 | + choices = list(self.torch_devices.keys()) | ||
90 | + value = choices[0] | ||
91 | + else: | ||
92 | + choices = list(self.plaidml_devices.keys()) | ||
93 | + value = choices[0] | ||
69 | self.lbl_device = wx.StaticText(self, -1, _("Device")) | 94 | self.lbl_device = wx.StaticText(self, -1, _("Device")) |
70 | self.cb_devices = wx.ComboBox( | 95 | self.cb_devices = wx.ComboBox( |
71 | self, | 96 | self, |
72 | wx.ID_ANY, | 97 | wx.ID_ANY, |
73 | - choices=list(self.plaidml_devices.keys()), | ||
74 | - value=list(self.plaidml_devices.keys())[0], | 98 | + choices=choices, |
99 | + value=value, | ||
75 | style=wx.CB_DROPDOWN | wx.CB_READONLY, | 100 | style=wx.CB_DROPDOWN | wx.CB_READONLY, |
76 | ) | 101 | ) |
77 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) | 102 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
@@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): | ||
109 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) | 134 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
110 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) | 135 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
111 | sizer_devices = wx.BoxSizer(wx.HORIZONTAL) | 136 | sizer_devices = wx.BoxSizer(wx.HORIZONTAL) |
112 | - if HAS_PLAIDML: | 137 | + if HAS_TORCH or HAS_PLAIDML: |
113 | sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) | 138 | sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) |
114 | sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) | 139 | sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) |
115 | main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) | 140 | main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) |
@@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): | ||
177 | return width, height | 202 | return width, height |
178 | 203 | ||
179 | def OnSetBackend(self, evt=None): | 204 | def OnSetBackend(self, evt=None): |
180 | - if self.cb_backends.GetValue().lower() == "plaidml": | 205 | + if self.cb_backends.GetValue().lower() == "pytorch": |
206 | + if HAS_TORCH: | ||
207 | + choices = list(self.torch_devices.keys()) | ||
208 | + self.cb_devices.Clear() | ||
209 | + self.cb_devices.SetItems(choices) | ||
210 | + self.cb_devices.SetValue(choices[0]) | ||
211 | + self.lbl_device.Show() | ||
212 | + self.cb_devices.Show() | ||
213 | + self.chk_use_gpu.Hide() | ||
214 | + elif self.cb_backends.GetValue().lower() == "plaidml": | ||
181 | if HAS_PLAIDML: | 215 | if HAS_PLAIDML: |
216 | + choices = list(self.plaidml_devices.keys()) | ||
217 | + self.cb_devices.Clear() | ||
218 | + self.cb_devices.SetItems(choices) | ||
219 | + self.cb_devices.SetValue(choices[0]) | ||
182 | self.lbl_device.Show() | 220 | self.lbl_device.Show() |
183 | self.cb_devices.Show() | 221 | self.cb_devices.Show() |
184 | self.chk_use_gpu.Hide() | 222 | self.chk_use_gpu.Hide() |
@@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): | @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): | ||
216 | self.elapsed_time_timer.Start(1000) | 254 | self.elapsed_time_timer.Start(1000) |
217 | image = slc.Slice().matrix | 255 | image = slc.Slice().matrix |
218 | backend = self.cb_backends.GetValue() | 256 | backend = self.cb_backends.GetValue() |
219 | - try: | ||
220 | - device_id = self.plaidml_devices[self.cb_devices.GetValue()] | ||
221 | - except (KeyError, AttributeError): | ||
222 | - device_id = "llvm_cpu.0" | 257 | + if backend.lower() == "pytorch": |
258 | + try: | ||
259 | + device_id = self.torch_devices[self.cb_devices.GetValue()] | ||
260 | + except (KeyError, AttributeError): | ||
261 | + device_id = "cpu" | ||
262 | + else: | ||
263 | + try: | ||
264 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | ||
265 | + except (KeyError, AttributeError): | ||
266 | + device_id = "llvm_cpu.0" | ||
223 | apply_wwwl = self.chk_apply_wwwl.GetValue() | 267 | apply_wwwl = self.chk_apply_wwwl.GetValue() |
224 | create_new_mask = self.chk_new_mask.GetValue() | 268 | create_new_mask = self.chk_new_mask.GetValue() |
225 | use_gpu = self.chk_use_gpu.GetValue() | 269 | use_gpu = self.chk_use_gpu.GetValue() |
invesalius/gui/dialogs.py
@@ -4632,7 +4632,8 @@ class SetNDIconfigs(wx.Dialog): | @@ -4632,7 +4632,8 @@ class SetNDIconfigs(wx.Dialog): | ||
4632 | self._init_gui() | 4632 | self._init_gui() |
4633 | 4633 | ||
4634 | def serial_ports(self): | 4634 | def serial_ports(self): |
4635 | - """ Lists serial port names and pre-select the description containing NDI | 4635 | + """ |
4636 | + Lists serial port names and pre-select the description containing NDI | ||
4636 | """ | 4637 | """ |
4637 | import serial.tools.list_ports | 4638 | import serial.tools.list_ports |
4638 | 4639 | ||
@@ -4748,13 +4749,16 @@ class SetNDIconfigs(wx.Dialog): | @@ -4748,13 +4749,16 @@ class SetNDIconfigs(wx.Dialog): | ||
4748 | return self.com_ports.GetString(self.com_ports.GetSelection()).encode(const.FS_ENCODE), fn_probe, fn_ref, fn_obj | 4749 | return self.com_ports.GetString(self.com_ports.GetSelection()).encode(const.FS_ENCODE), fn_probe, fn_ref, fn_obj |
4749 | 4750 | ||
4750 | 4751 | ||
4751 | -class SetCOMport(wx.Dialog): | ||
4752 | - def __init__(self, title=_("Select COM port")): | ||
4753 | - wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE|wx.FRAME_FLOAT_ON_PARENT|wx.STAY_ON_TOP) | 4752 | +class SetCOMPort(wx.Dialog): |
4753 | + def __init__(self, select_baud_rate, title=_("Select COM port")): | ||
4754 | + wx.Dialog.__init__(self, wx.GetApp().GetTopWindow(), -1, title, style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT | wx.STAY_ON_TOP) | ||
4755 | + | ||
4756 | + self.select_baud_rate = select_baud_rate | ||
4754 | self._init_gui() | 4757 | self._init_gui() |
4755 | 4758 | ||
4756 | def serial_ports(self): | 4759 | def serial_ports(self): |
4757 | - """ Lists serial port names | 4760 | + """ |
4761 | + Lists serial port names | ||
4758 | """ | 4762 | """ |
4759 | import serial.tools.list_ports | 4763 | import serial.tools.list_ports |
4760 | if sys.platform.startswith('win'): | 4764 | if sys.platform.startswith('win'): |
@@ -4764,12 +4768,26 @@ class SetCOMport(wx.Dialog): | @@ -4764,12 +4768,26 @@ class SetCOMport(wx.Dialog): | ||
4764 | return ports | 4768 | return ports |
4765 | 4769 | ||
4766 | def _init_gui(self): | 4770 | def _init_gui(self): |
4767 | - self.com_ports = wx.ComboBox(self, -1, style=wx.CB_DROPDOWN|wx.CB_READONLY) | 4771 | + # COM port selection |
4768 | ports = self.serial_ports() | 4772 | ports = self.serial_ports() |
4769 | - self.com_ports.Append(ports) | 4773 | + self.com_port_dropdown = wx.ComboBox(self, -1, choices=ports, style=wx.CB_DROPDOWN | wx.CB_READONLY) |
4774 | + self.com_port_dropdown.SetSelection(0) | ||
4775 | + | ||
4776 | + com_port_text_and_dropdown = wx.BoxSizer(wx.VERTICAL) | ||
4777 | + com_port_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "COM port"), 0, wx.TOP | wx.RIGHT,5) | ||
4778 | + com_port_text_and_dropdown.Add(self.com_port_dropdown, 0, wx.EXPAND) | ||
4779 | + | ||
4780 | + # Baud rate selection | ||
4781 | + if self.select_baud_rate: | ||
4782 | + baud_rates_as_strings = [str(baud_rate) for baud_rate in const.BAUD_RATES] | ||
4783 | + self.baud_rate_dropdown = wx.ComboBox(self, -1, choices=baud_rates_as_strings, style=wx.CB_DROPDOWN | wx.CB_READONLY) | ||
4784 | + self.baud_rate_dropdown.SetSelection(const.BAUD_RATE_DEFAULT_SELECTION) | ||
4770 | 4785 | ||
4771 | - # self.goto_orientation.SetSelection(cb_init) | 4786 | + baud_rate_text_and_dropdown = wx.BoxSizer(wx.VERTICAL) |
4787 | + baud_rate_text_and_dropdown.Add(wx.StaticText(self, wx.ID_ANY, "Baud rate"), 0, wx.TOP | wx.RIGHT,5) | ||
4788 | + baud_rate_text_and_dropdown.Add(self.baud_rate_dropdown, 0, wx.EXPAND) | ||
4772 | 4789 | ||
4790 | + # OK and Cancel buttons | ||
4773 | btn_ok = wx.Button(self, wx.ID_OK) | 4791 | btn_ok = wx.Button(self, wx.ID_OK) |
4774 | btn_ok.SetHelpText("") | 4792 | btn_ok.SetHelpText("") |
4775 | btn_ok.SetDefault() | 4793 | btn_ok.SetDefault() |
@@ -4782,10 +4800,16 @@ class SetCOMport(wx.Dialog): | @@ -4782,10 +4800,16 @@ class SetCOMport(wx.Dialog): | ||
4782 | btnsizer.AddButton(btn_cancel) | 4800 | btnsizer.AddButton(btn_cancel) |
4783 | btnsizer.Realize() | 4801 | btnsizer.Realize() |
4784 | 4802 | ||
4803 | + # Set up the main sizer | ||
4785 | main_sizer = wx.BoxSizer(wx.VERTICAL) | 4804 | main_sizer = wx.BoxSizer(wx.VERTICAL) |
4786 | 4805 | ||
4787 | main_sizer.Add((5, 5)) | 4806 | main_sizer.Add((5, 5)) |
4788 | - main_sizer.Add(self.com_ports, 1, wx.EXPAND|wx.LEFT|wx.RIGHT, 5) | 4807 | + main_sizer.Add(com_port_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5) |
4808 | + | ||
4809 | + if self.select_baud_rate: | ||
4810 | + main_sizer.Add((5, 5)) | ||
4811 | + main_sizer.Add(baud_rate_text_and_dropdown, 1, wx.EXPAND | wx.LEFT | wx.RIGHT, 5) | ||
4812 | + | ||
4789 | main_sizer.Add((5, 5)) | 4813 | main_sizer.Add((5, 5)) |
4790 | main_sizer.Add(btnsizer, 0, wx.EXPAND) | 4814 | main_sizer.Add(btnsizer, 0, wx.EXPAND) |
4791 | main_sizer.Add((5, 5)) | 4815 | main_sizer.Add((5, 5)) |
@@ -4796,7 +4820,14 @@ class SetCOMport(wx.Dialog): | @@ -4796,7 +4820,14 @@ class SetCOMport(wx.Dialog): | ||
4796 | self.CenterOnParent() | 4820 | self.CenterOnParent() |
4797 | 4821 | ||
4798 | def GetValue(self): | 4822 | def GetValue(self): |
4799 | - return self.com_ports.GetString(self.com_ports.GetSelection()) | 4823 | + com_port = self.com_port_dropdown.GetString(self.com_port_dropdown.GetSelection()) |
4824 | + | ||
4825 | + if self.select_baud_rate: | ||
4826 | + baud_rate = self.baud_rate_dropdown.GetString(self.baud_rate_dropdown.GetSelection()) | ||
4827 | + else: | ||
4828 | + baud_rate = None | ||
4829 | + | ||
4830 | + return com_port, baud_rate | ||
4800 | 4831 | ||
4801 | 4832 | ||
4802 | class ManualWWWLDialog(wx.Dialog): | 4833 | class ManualWWWLDialog(wx.Dialog): |
invesalius/gui/task_navigator.py
@@ -234,8 +234,8 @@ class InnerFoldPanel(wx.Panel): | @@ -234,8 +234,8 @@ class InnerFoldPanel(wx.Panel): | ||
234 | checkcamera.Bind(wx.EVT_CHECKBOX, self.OnVolumeCamera) | 234 | checkcamera.Bind(wx.EVT_CHECKBOX, self.OnVolumeCamera) |
235 | self.checkcamera = checkcamera | 235 | self.checkcamera = checkcamera |
236 | 236 | ||
237 | - # Check box to create markers from serial port | ||
238 | - tooltip = wx.ToolTip(_("Enable serial port communication for creating markers")) | 237 | + # Check box to use serial port to trigger pulse signal and create markers |
238 | + tooltip = wx.ToolTip(_("Enable serial port communication to trigger pulse and create markers")) | ||
239 | checkbox_serial_port = wx.CheckBox(self, -1, _('Serial port')) | 239 | checkbox_serial_port = wx.CheckBox(self, -1, _('Serial port')) |
240 | checkbox_serial_port.SetToolTip(tooltip) | 240 | checkbox_serial_port.SetToolTip(tooltip) |
241 | checkbox_serial_port.SetValue(False) | 241 | checkbox_serial_port.SetValue(False) |
@@ -297,14 +297,20 @@ class InnerFoldPanel(wx.Panel): | @@ -297,14 +297,20 @@ class InnerFoldPanel(wx.Panel): | ||
297 | self.checkobj.Enable(True) | 297 | self.checkobj.Enable(True) |
298 | 298 | ||
299 | def OnEnableSerialPort(self, evt, ctrl): | 299 | def OnEnableSerialPort(self, evt, ctrl): |
300 | - com_port = None | ||
301 | if ctrl.GetValue(): | 300 | if ctrl.GetValue(): |
302 | from wx import ID_OK | 301 | from wx import ID_OK |
303 | - dlg_port = dlg.SetCOMport() | ||
304 | - if dlg_port.ShowModal() == ID_OK: | ||
305 | - com_port = dlg_port.GetValue() | 302 | + dlg_port = dlg.SetCOMPort(select_baud_rate=False) |
306 | 303 | ||
307 | - Publisher.sendMessage('Update serial port', serial_port=com_port) | 304 | + if dlg_port.ShowModal() != ID_OK: |
305 | + ctrl.SetValue(False) | ||
306 | + return | ||
307 | + | ||
308 | + com_port = dlg_port.GetValue() | ||
309 | + baud_rate = 115200 | ||
310 | + | ||
311 | + Publisher.sendMessage('Update serial port', serial_port_in_use=True, com_port=com_port, baud_rate=baud_rate) | ||
312 | + else: | ||
313 | + Publisher.sendMessage('Update serial port', serial_port_in_use=False) | ||
308 | 314 | ||
309 | def OnShowObject(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): | 315 | def OnShowObject(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): |
310 | if not evt: | 316 | if not evt: |
@@ -503,7 +509,6 @@ class NeuronavigationPanel(wx.Panel): | @@ -503,7 +509,6 @@ class NeuronavigationPanel(wx.Panel): | ||
503 | Publisher.subscribe(self.LoadImageFiducials, 'Load image fiducials') | 509 | Publisher.subscribe(self.LoadImageFiducials, 'Load image fiducials') |
504 | Publisher.subscribe(self.SetImageFiducial, 'Set image fiducial') | 510 | Publisher.subscribe(self.SetImageFiducial, 'Set image fiducial') |
505 | Publisher.subscribe(self.SetTrackerFiducial, 'Set tracker fiducial') | 511 | Publisher.subscribe(self.SetTrackerFiducial, 'Set tracker fiducial') |
506 | - Publisher.subscribe(self.UpdateSerialPort, 'Update serial port') | ||
507 | Publisher.subscribe(self.UpdateTrackObjectState, 'Update track object state') | 512 | Publisher.subscribe(self.UpdateTrackObjectState, 'Update track object state') |
508 | Publisher.subscribe(self.UpdateImageCoordinates, 'Set cross focal point') | 513 | Publisher.subscribe(self.UpdateImageCoordinates, 'Set cross focal point') |
509 | Publisher.subscribe(self.OnDisconnectTracker, 'Disconnect tracker') | 514 | Publisher.subscribe(self.OnDisconnectTracker, 'Disconnect tracker') |
@@ -627,9 +632,6 @@ class NeuronavigationPanel(wx.Panel): | @@ -627,9 +632,6 @@ class NeuronavigationPanel(wx.Panel): | ||
627 | def UpdateTrackObjectState(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): | 632 | def UpdateTrackObjectState(self, evt=None, flag=None, obj_name=None, polydata=None, use_default_object=True): |
628 | self.navigation.track_obj = flag | 633 | self.navigation.track_obj = flag |
629 | 634 | ||
630 | - def UpdateSerialPort(self, serial_port): | ||
631 | - self.navigation.serial_port = serial_port | ||
632 | - | ||
633 | def ResetICP(self): | 635 | def ResetICP(self): |
634 | self.icp.ResetICP() | 636 | self.icp.ResetICP() |
635 | self.checkbox_icp.Enable(False) | 637 | self.checkbox_icp.Enable(False) |
invesalius/inv_paths.py
@@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co | @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co | ||
27 | USER_INV_DIR = CONF_DIR.joinpath("invesalius") | 27 | USER_INV_DIR = CONF_DIR.joinpath("invesalius") |
28 | USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") | 28 | USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") |
29 | USER_LOG_DIR = USER_INV_DIR.joinpath("logs") | 29 | USER_LOG_DIR = USER_INV_DIR.joinpath("logs") |
30 | +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/") | ||
30 | USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") | 31 | USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") |
31 | TEMP_DIR = tempfile.gettempdir() | 32 | TEMP_DIR = tempfile.gettempdir() |
32 | 33 | ||
@@ -97,6 +98,7 @@ def create_conf_folders(): | @@ -97,6 +98,7 @@ def create_conf_folders(): | ||
97 | USER_INV_DIR.mkdir(parents=True, exist_ok=True) | 98 | USER_INV_DIR.mkdir(parents=True, exist_ok=True) |
98 | USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) | 99 | USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) |
99 | USER_LOG_DIR.mkdir(parents=True, exist_ok=True) | 100 | USER_LOG_DIR.mkdir(parents=True, exist_ok=True) |
101 | + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True) | ||
100 | USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) | 102 | USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) |
101 | 103 | ||
102 | 104 |
invesalius/navigation/navigation.py
@@ -171,7 +171,9 @@ class Navigation(): | @@ -171,7 +171,9 @@ class Navigation(): | ||
171 | self.sleep_nav = const.SLEEP_NAVIGATION | 171 | self.sleep_nav = const.SLEEP_NAVIGATION |
172 | 172 | ||
173 | # Serial port | 173 | # Serial port |
174 | - self.serial_port = None | 174 | + self.serial_port_in_use = False |
175 | + self.com_port = None | ||
176 | + self.baud_rate = None | ||
175 | self.serial_port_connection = None | 177 | self.serial_port_connection = None |
176 | 178 | ||
177 | # During navigation | 179 | # During navigation |
@@ -181,6 +183,7 @@ class Navigation(): | @@ -181,6 +183,7 @@ class Navigation(): | ||
181 | 183 | ||
182 | def __bind_events(self): | 184 | def __bind_events(self): |
183 | Publisher.subscribe(self.CoilAtTarget, 'Coil at target') | 185 | Publisher.subscribe(self.CoilAtTarget, 'Coil at target') |
186 | + Publisher.subscribe(self.UpdateSerialPort, 'Update serial port') | ||
184 | 187 | ||
185 | def CoilAtTarget(self, state): | 188 | def CoilAtTarget(self, state): |
186 | self.coil_at_target = state | 189 | self.coil_at_target = state |
@@ -189,8 +192,10 @@ class Navigation(): | @@ -189,8 +192,10 @@ class Navigation(): | ||
189 | self.sleep_nav = sleep | 192 | self.sleep_nav = sleep |
190 | self.serial_port_connection.sleep_nav = sleep | 193 | self.serial_port_connection.sleep_nav = sleep |
191 | 194 | ||
192 | - def SerialPortEnabled(self): | ||
193 | - return self.serial_port is not None | 195 | + def UpdateSerialPort(self, serial_port_in_use, com_port=None, baud_rate=None): |
196 | + self.serial_port_in_use = serial_port_in_use | ||
197 | + self.com_port = com_port | ||
198 | + self.baud_rate = baud_rate | ||
194 | 199 | ||
195 | def SetReferenceMode(self, value): | 200 | def SetReferenceMode(self, value): |
196 | self.ref_mode_id = value | 201 | self.ref_mode_id = value |
@@ -218,7 +223,7 @@ class Navigation(): | @@ -218,7 +223,7 @@ class Navigation(): | ||
218 | return fre, fre <= const.FIDUCIAL_REGISTRATION_ERROR_THRESHOLD | 223 | return fre, fre <= const.FIDUCIAL_REGISTRATION_ERROR_THRESHOLD |
219 | 224 | ||
220 | def PedalStateChanged(self, state): | 225 | def PedalStateChanged(self, state): |
221 | - if state is True and self.coil_at_target and self.SerialPortEnabled(): | 226 | + if state is True and self.coil_at_target and self.serial_port_in_use: |
222 | self.serial_port_connection.SendPulse() | 227 | self.serial_port_connection.SendPulse() |
223 | 228 | ||
224 | def StartNavigation(self, tracker): | 229 | def StartNavigation(self, tracker): |
@@ -230,7 +235,7 @@ class Navigation(): | @@ -230,7 +235,7 @@ class Navigation(): | ||
230 | if self.event.is_set(): | 235 | if self.event.is_set(): |
231 | self.event.clear() | 236 | self.event.clear() |
232 | 237 | ||
233 | - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded] | 238 | + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded] |
234 | vis_queues = [self.coord_queue, self.serial_port_queue, self.tracts_queue, self.icp_queue, self.robottarget_queue] | 239 | vis_queues = [self.coord_queue, self.serial_port_queue, self.tracts_queue, self.icp_queue, self.robottarget_queue] |
235 | 240 | ||
236 | Publisher.sendMessage("Navigation status", nav_status=True, vis_status=vis_components) | 241 | Publisher.sendMessage("Navigation status", nav_status=True, vis_status=vis_components) |
@@ -279,12 +284,13 @@ class Navigation(): | @@ -279,12 +284,13 @@ class Navigation(): | ||
279 | 284 | ||
280 | if not errors: | 285 | if not errors: |
281 | #TODO: Test the serial port thread | 286 | #TODO: Test the serial port thread |
282 | - if self.SerialPortEnabled(): | 287 | + if self.serial_port_in_use: |
283 | self.serial_port_connection = spc.SerialPortConnection( | 288 | self.serial_port_connection = spc.SerialPortConnection( |
284 | - self.serial_port, | ||
285 | - self.serial_port_queue, | ||
286 | - self.event, | ||
287 | - self.sleep_nav, | 289 | + com_port=self.com_port, |
290 | + baud_rate=self.baud_rate, | ||
291 | + serial_port_queue=self.serial_port_queue, | ||
292 | + event=self.event, | ||
293 | + sleep_nav=self.sleep_nav, | ||
288 | ) | 294 | ) |
289 | self.serial_port_connection.Connect() | 295 | self.serial_port_connection.Connect() |
290 | jobs_list.append(self.serial_port_connection) | 296 | jobs_list.append(self.serial_port_connection) |
@@ -330,7 +336,7 @@ class Navigation(): | @@ -330,7 +336,7 @@ class Navigation(): | ||
330 | if self.serial_port_connection is not None: | 336 | if self.serial_port_connection is not None: |
331 | self.serial_port_connection.join() | 337 | self.serial_port_connection.join() |
332 | 338 | ||
333 | - if self.SerialPortEnabled(): | 339 | + if self.serial_port_in_use: |
334 | self.serial_port_queue.clear() | 340 | self.serial_port_queue.clear() |
335 | self.serial_port_queue.join() | 341 | self.serial_port_queue.join() |
336 | 342 | ||
@@ -341,5 +347,5 @@ class Navigation(): | @@ -341,5 +347,5 @@ class Navigation(): | ||
341 | self.tracts_queue.clear() | 347 | self.tracts_queue.clear() |
342 | self.tracts_queue.join() | 348 | self.tracts_queue.join() |
343 | 349 | ||
344 | - vis_components = [self.SerialPortEnabled(), self.view_tracts, self.peel_loaded] | 350 | + vis_components = [self.serial_port_in_use, self.view_tracts, self.peel_loaded] |
345 | Publisher.sendMessage("Navigation status", nav_status=False, vis_status=vis_components) | 351 | Publisher.sendMessage("Navigation status", nav_status=False, vis_status=vis_components) |
@@ -0,0 +1,48 @@ | @@ -0,0 +1,48 @@ | ||
1 | +from urllib.error import HTTPError | ||
2 | +from urllib.request import urlopen, Request | ||
3 | +from urllib.parse import urlparse | ||
4 | +import pathlib | ||
5 | +import tempfile | ||
6 | +import typing | ||
7 | +import hashlib | ||
8 | +import os | ||
9 | +import shutil | ||
10 | + | ||
11 | +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None): | ||
12 | + file_size = None | ||
13 | + total_downloaded = 0 | ||
14 | + if hash is not None: | ||
15 | + calc_hash = hashlib.sha256() | ||
16 | + req = Request(url) | ||
17 | + response = urlopen(req) | ||
18 | + meta = response.info() | ||
19 | + if hasattr(meta, "getheaders"): | ||
20 | + content_length = meta.getheaders("Content-Length") | ||
21 | + else: | ||
22 | + content_length = meta.get_all("Content-Length") | ||
23 | + | ||
24 | + if content_length is not None and len(content_length) > 0: | ||
25 | + file_size = int(content_length[0]) | ||
26 | + dst.parent.mkdir(parents=True, exist_ok=True) | ||
27 | + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent) | ||
28 | + try: | ||
29 | + while True: | ||
30 | + buffer = response.read(8192) | ||
31 | + if len(buffer) == 0: | ||
32 | + break | ||
33 | + total_downloaded += len(buffer) | ||
34 | + f.write(buffer) | ||
35 | + if hash: | ||
36 | + calc_hash.update(buffer) | ||
37 | + if callback is not None: | ||
38 | + callback(100 * total_downloaded/file_size) | ||
39 | + f.close() | ||
40 | + if hash is not None: | ||
41 | + digest = calc_hash.hexdigest() | ||
42 | + if digest != hash: | ||
43 | + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")') | ||
44 | + shutil.move(f.name, dst) | ||
45 | + finally: | ||
46 | + f.close() | ||
47 | + if os.path.exists(f.name): | ||
48 | + os.remove(f.name) |
@@ -0,0 +1,149 @@ | @@ -0,0 +1,149 @@ | ||
1 | +from collections import OrderedDict | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.nn as nn | ||
5 | + | ||
6 | +SIZE = 48 | ||
7 | + | ||
8 | +class Unet3D(nn.Module): | ||
9 | + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py | ||
10 | + def __init__(self, in_channels=1, out_channels=1, init_features=8): | ||
11 | + super().__init__() | ||
12 | + features = init_features | ||
13 | + | ||
14 | + self.encoder1 = self._block( | ||
15 | + in_channels, features=features, padding=2, name="enc1" | ||
16 | + ) | ||
17 | + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) | ||
18 | + | ||
19 | + self.encoder2 = self._block( | ||
20 | + features, features=features * 2, padding=2, name="enc2" | ||
21 | + ) | ||
22 | + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) | ||
23 | + | ||
24 | + self.encoder3 = self._block( | ||
25 | + features * 2, features=features * 4, padding=2, name="enc3" | ||
26 | + ) | ||
27 | + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) | ||
28 | + | ||
29 | + self.encoder4 = self._block( | ||
30 | + features * 4, features=features * 8, padding=2, name="enc4" | ||
31 | + ) | ||
32 | + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) | ||
33 | + | ||
34 | + self.bottleneck = self._block( | ||
35 | + features * 8, features=features * 16, padding=2, name="bottleneck" | ||
36 | + ) | ||
37 | + | ||
38 | + self.upconv4 = nn.ConvTranspose3d( | ||
39 | + features * 16, features * 8, kernel_size=4, stride=2, padding=1 | ||
40 | + ) | ||
41 | + self.decoder4 = self._block( | ||
42 | + features * 16, features=features * 8, padding=2, name="dec4" | ||
43 | + ) | ||
44 | + | ||
45 | + self.upconv3 = nn.ConvTranspose3d( | ||
46 | + features * 8, features * 4, kernel_size=4, stride=2, padding=1 | ||
47 | + ) | ||
48 | + self.decoder3 = self._block( | ||
49 | + features * 8, features=features * 4, padding=2, name="dec4" | ||
50 | + ) | ||
51 | + | ||
52 | + self.upconv2 = nn.ConvTranspose3d( | ||
53 | + features * 4, features * 2, kernel_size=4, stride=2, padding=1 | ||
54 | + ) | ||
55 | + self.decoder2 = self._block( | ||
56 | + features * 4, features=features * 2, padding=2, name="dec4" | ||
57 | + ) | ||
58 | + | ||
59 | + self.upconv1 = nn.ConvTranspose3d( | ||
60 | + features * 2, features, kernel_size=4, stride=2, padding=1 | ||
61 | + ) | ||
62 | + self.decoder1 = self._block( | ||
63 | + features * 2, features=features, padding=2, name="dec4" | ||
64 | + ) | ||
65 | + | ||
66 | + self.conv = nn.Conv3d( | ||
67 | + in_channels=features, out_channels=out_channels, kernel_size=1 | ||
68 | + ) | ||
69 | + | ||
70 | + def forward(self, img): | ||
71 | + enc1 = self.encoder1(img) | ||
72 | + enc2 = self.encoder2(self.pool1(enc1)) | ||
73 | + enc3 = self.encoder3(self.pool2(enc2)) | ||
74 | + enc4 = self.encoder4(self.pool3(enc3)) | ||
75 | + | ||
76 | + bottleneck = self.bottleneck(self.pool4(enc4)) | ||
77 | + | ||
78 | + upconv4 = self.upconv4(bottleneck) | ||
79 | + dec4 = torch.cat((upconv4, enc4), dim=1) | ||
80 | + dec4 = self.decoder4(dec4) | ||
81 | + | ||
82 | + upconv3 = self.upconv3(dec4) | ||
83 | + dec3 = torch.cat((upconv3, enc3), dim=1) | ||
84 | + dec3 = self.decoder3(dec3) | ||
85 | + | ||
86 | + upconv2 = self.upconv2(dec3) | ||
87 | + dec2 = torch.cat((upconv2, enc2), dim=1) | ||
88 | + dec2 = self.decoder2(dec2) | ||
89 | + | ||
90 | + upconv1 = self.upconv1(dec2) | ||
91 | + dec1 = torch.cat((upconv1, enc1), dim=1) | ||
92 | + dec1 = self.decoder1(dec1) | ||
93 | + | ||
94 | + conv = self.conv(dec1) | ||
95 | + | ||
96 | + sigmoid = torch.sigmoid(conv) | ||
97 | + | ||
98 | + return sigmoid | ||
99 | + | ||
100 | + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"): | ||
101 | + return nn.Sequential( | ||
102 | + OrderedDict( | ||
103 | + ( | ||
104 | + ( | ||
105 | + f"{name}_conv1", | ||
106 | + nn.Conv3d( | ||
107 | + in_channels=in_channels, | ||
108 | + out_channels=features, | ||
109 | + kernel_size=kernel_size, | ||
110 | + padding=padding, | ||
111 | + bias=True, | ||
112 | + ), | ||
113 | + ), | ||
114 | + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)), | ||
115 | + (f"{name}_relu1", nn.ReLU(inplace=True)), | ||
116 | + ( | ||
117 | + f"{name}_conv2", | ||
118 | + nn.Conv3d( | ||
119 | + in_channels=features, | ||
120 | + out_channels=features, | ||
121 | + kernel_size=kernel_size, | ||
122 | + padding=padding, | ||
123 | + bias=True, | ||
124 | + ), | ||
125 | + ), | ||
126 | + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)), | ||
127 | + (f"{name}_relu2", nn.ReLU(inplace=True)), | ||
128 | + ) | ||
129 | + ) | ||
130 | + ) | ||
131 | + | ||
132 | + | ||
133 | +def main(): | ||
134 | + import torchviz | ||
135 | + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | ||
136 | + model = Unet3D() | ||
137 | + model.to(dev) | ||
138 | + model.eval() | ||
139 | + print(next(model.parameters()).is_cuda) # True | ||
140 | + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev) | ||
141 | + out = model(img) | ||
142 | + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) | ||
143 | + dot.render("unet", format="png") | ||
144 | + torch.save(model, "model.pth") | ||
145 | + print(dot) | ||
146 | + | ||
147 | + | ||
148 | +if __name__ == "__main__": | ||
149 | + main() |
invesalius/segmentation/brain/segment.py
@@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc | @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc | ||
13 | from invesalius import inv_paths | 13 | from invesalius import inv_paths |
14 | from invesalius.data import imagedata_utils | 14 | from invesalius.data import imagedata_utils |
15 | from invesalius.utils import new_name_by_pattern | 15 | from invesalius.utils import new_name_by_pattern |
16 | +from invesalius.net.utils import download_url_to_file | ||
17 | +from invesalius import inv_paths | ||
16 | 18 | ||
17 | from . import utils | 19 | from . import utils |
18 | 20 | ||
@@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): | @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): | ||
64 | 0 : ez - iz, 0 : ey - iy, 0 : ex - ix | 66 | 0 : ez - iz, 0 : ey - iy, 0 : ex - ix |
65 | ] | 67 | ] |
66 | 68 | ||
69 | +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE): | ||
70 | + import torch | ||
71 | + with torch.no_grad(): | ||
72 | + (iz, ez), (iy, ey), (ix, ex) = patch | ||
73 | + sub_mask = nn_model( | ||
74 | + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device) | ||
75 | + ).cpu().numpy() | ||
76 | + return sub_mask.reshape(patch_size, patch_size, patch_size)[ | ||
77 | + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix | ||
78 | + ] | ||
79 | + | ||
67 | 80 | ||
68 | def brain_segment(image, probability_array, comm_array): | 81 | def brain_segment(image, probability_array, comm_array): |
69 | import keras | 82 | import keras |
@@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): | @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): | ||
89 | comm_array[0] = np.Inf | 102 | comm_array[0] = np.Inf |
90 | 103 | ||
91 | 104 | ||
105 | +def download_callback(comm_array): | ||
106 | + def _download_callback(value): | ||
107 | + comm_array[0] = value | ||
108 | + return _download_callback | ||
109 | + | ||
110 | +def brain_segment_torch(image, device_id, probability_array, comm_array): | ||
111 | + import torch | ||
112 | + from .model import Unet3D | ||
113 | + device = torch.device(device_id) | ||
114 | + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | ||
115 | + if not state_dict_file.exists(): | ||
116 | + download_url_to_file( | ||
117 | + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", | ||
118 | + state_dict_file, | ||
119 | + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", | ||
120 | + download_callback(comm_array) | ||
121 | + ) | ||
122 | + state_dict = torch.load(str(state_dict_file)) | ||
123 | + model = Unet3D() | ||
124 | + model.load_state_dict(state_dict["model_state_dict"]) | ||
125 | + model.to(device) | ||
126 | + model.eval() | ||
127 | + | ||
128 | + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32) | ||
129 | + sums = np.zeros_like(image) | ||
130 | + # segmenting by patches | ||
131 | + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP): | ||
132 | + comm_array[0] = completion | ||
133 | + (iz, ez), (iy, ey), (ix, ex) = patch | ||
134 | + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE) | ||
135 | + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask | ||
136 | + sums[iz:ez, iy:ey, ix:ex] += 1 | ||
137 | + | ||
138 | + probability_array /= sums | ||
139 | + comm_array[0] = np.Inf | ||
140 | + | ||
92 | ctx = multiprocessing.get_context('spawn') | 141 | ctx = multiprocessing.get_context('spawn') |
93 | class SegmentProcess(ctx.Process): | 142 | class SegmentProcess(ctx.Process): |
94 | def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): | 143 | def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): |
@@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): | @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): | ||
138 | mode="r", | 187 | mode="r", |
139 | ) | 188 | ) |
140 | 189 | ||
141 | - print(image.min(), image.max()) | ||
142 | - if self.apply_segment_threshold: | 190 | + if self.apply_wwwl: |
143 | print("Applying window level") | 191 | print("Applying window level") |
144 | image = get_LUT_value(image, self.window_width, self.window_level) | 192 | image = get_LUT_value(image, self.window_width, self.window_level) |
145 | 193 | ||
@@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): | @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): | ||
153 | self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" | 201 | self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" |
154 | ) | 202 | ) |
155 | 203 | ||
156 | - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | ||
157 | - brain_segment(image, probability_array, comm_array) | 204 | + if self.backend.lower() == "pytorch": |
205 | + brain_segment_torch(image, self.device_id, probability_array, comm_array) | ||
206 | + else: | ||
207 | + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | ||
208 | + brain_segment(image, probability_array, comm_array) | ||
158 | 209 | ||
159 | @property | 210 | @property |
160 | def exception(self): | 211 | def exception(self): |
optional-requirements.txt
requirements.txt
@@ -0,0 +1,88 @@ | @@ -0,0 +1,88 @@ | ||
1 | +#!/usr/bin/env python3 | ||
2 | +# -*- coding: utf-8 -*- | ||
3 | + | ||
4 | +# This scripts allows sending events to InVesalius via Socket.IO, mimicking InVesalius's | ||
5 | +# internal communication. It can be useful for developing and debugging InVesalius. | ||
6 | +# | ||
7 | +# Example usage: | ||
8 | +# | ||
9 | +# - (In console window 1) Run the script by: python scripts/invesalius_server.py 5000 | ||
10 | +# | ||
11 | +# - (In console window 2) Run InVesalius by: python app.py --remote-host http://localhost:5000 | ||
12 | +# | ||
13 | +# - If InVesalius connected to the server successfully, a message should appear in console window 1, | ||
14 | +# asking to provide the topic name. | ||
15 | +# | ||
16 | +# - Enter the topic name, such as "Add marker" (without quotes). | ||
17 | +# | ||
18 | +# - Enter the data, such as {"ball_id": 0, "size": 2, "colour": [1.0, 1.0, 0.0], "coord": [10.0, 20.0, 30.0]} | ||
19 | +# | ||
20 | +# - If successful, a message should now appear in console window 2, indicating that the event was received. | ||
21 | + | ||
22 | +import asyncio | ||
23 | +import sys | ||
24 | +import json | ||
25 | + | ||
26 | +import aioconsole | ||
27 | +import nest_asyncio | ||
28 | +import socketio | ||
29 | +import uvicorn | ||
30 | + | ||
31 | +nest_asyncio.apply() | ||
32 | + | ||
33 | +if len(sys.argv) != 2: | ||
34 | + print ("""This script allows sending events to InVesalius. | ||
35 | + | ||
36 | +Usage: python invesalius_server.py port""") | ||
37 | + sys.exit(1) | ||
38 | + | ||
39 | +port = int(sys.argv[1]) | ||
40 | + | ||
41 | +sio = socketio.AsyncServer(async_mode='asgi') | ||
42 | +app = socketio.ASGIApp(sio) | ||
43 | + | ||
44 | +connected = False | ||
45 | + | ||
46 | +@sio.event | ||
47 | +def connect(sid, environ): | ||
48 | + global connected | ||
49 | + connected = True | ||
50 | + | ||
51 | +def print_json_error(e): | ||
52 | + print("Invalid JSON") | ||
53 | + print(e.doc) | ||
54 | + print(" " * e.pos + "^") | ||
55 | + print(e.msg) | ||
56 | + print("") | ||
57 | + | ||
58 | +async def run(): | ||
59 | + while True: | ||
60 | + if not connected: | ||
61 | + await asyncio.sleep(1) | ||
62 | + continue | ||
63 | + | ||
64 | + print("Enter topic: ") | ||
65 | + topic = await aioconsole.ainput() | ||
66 | + print("Enter data as JSON: ") | ||
67 | + data = await aioconsole.ainput() | ||
68 | + | ||
69 | + try: | ||
70 | + decoded = json.loads(data) | ||
71 | + except json.decoder.JSONDecodeError as e: | ||
72 | + print_json_error(e) | ||
73 | + continue | ||
74 | + | ||
75 | + await sio.emit( | ||
76 | + event="to_neuronavigation", | ||
77 | + data={ | ||
78 | + "topic": topic, | ||
79 | + "data": decoded, | ||
80 | + } | ||
81 | + ) | ||
82 | + | ||
83 | +async def main(): | ||
84 | + asyncio.create_task(run()) | ||
85 | + uvicorn.run(app, port=port, host='0.0.0.0', loop='asyncio') | ||
86 | + | ||
87 | +if __name__ == '__main__': | ||
88 | + asyncio.run(main(), debug=True) |