Commit a406f61e7cce53232cda3932c8981e993deb61fb
Committed by GitHub
1 parent 9460b9cd
Exists in master

Add pytorch backend to brain segmentation (#365)
* Added pytorch to requirements.txt and show it in the GUI
* Segment using pytorch
* Remove debug prints
* Added a function to download a file and save it locally
* Download the weights using the new download function
Showing 6 changed files with 308 additions and 13 deletions
invesalius/gui/brain_seg_dialog.py
@@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano"))
 HAS_PLAIDML = bool(importlib.util.find_spec("plaidml"))
 PLAIDML_DEVICES = {}

+try:
+    import torch
+    HAS_TORCH = True
+except ImportError:
+    HAS_TORCH = False
+
+if HAS_TORCH:
+    TORCH_DEVICES = {}
+    if torch.cuda.is_available():
+        for i in range(torch.cuda.device_count()):
+            name = torch.cuda.get_device_name(i)
+            device_id = f'cuda:{i}'
+            TORCH_DEVICES[name] = device_id
+    TORCH_DEVICES['CPU'] = 'cpu'
+
+

 if HAS_PLAIDML:
     with multiprocessing.Pool(1) as p:
@@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog):
             style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT,
         )
         backends = []
+        if HAS_TORCH:
+            backends.append("Pytorch")
         if HAS_PLAIDML:
             backends.append("PlaidML")
         if HAS_THEANO:
             backends.append("Theano")
         # self.segmenter = segment.BrainSegmenter()
         # self.pg_dialog = None
+        self.torch_devices = TORCH_DEVICES
         self.plaidml_devices = PLAIDML_DEVICES

         self.ps = None
@@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog):
         w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends)))
         self.cb_backends.SetMinClientSize((w, -1))
         self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU"))
-        if HAS_PLAIDML:
+        if HAS_TORCH or HAS_PLAIDML:
+            if HAS_TORCH:
+                choices = list(self.torch_devices.keys())
+                value = choices[0]
+            else:
+                choices = list(self.plaidml_devices.keys())
+                value = choices[0]
             self.lbl_device = wx.StaticText(self, -1, _("Device"))
             self.cb_devices = wx.ComboBox(
                 self,
                 wx.ID_ANY,
-                choices=list(self.plaidml_devices.keys()),
-                value=list(self.plaidml_devices.keys())[0],
+                choices=choices,
+                value=value,
                 style=wx.CB_DROPDOWN | wx.CB_READONLY,
             )
         self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100)
@@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog):
         main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5)
         main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5)
         sizer_devices = wx.BoxSizer(wx.HORIZONTAL)
-        if HAS_PLAIDML:
+        if HAS_TORCH or HAS_PLAIDML:
            sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0)
            sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5)
         main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5)
@@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog):
         return width, height

     def OnSetBackend(self, evt=None):
-        if self.cb_backends.GetValue().lower() == "plaidml":
+        if self.cb_backends.GetValue().lower() == "pytorch":
+            if HAS_TORCH:
+                choices = list(self.torch_devices.keys())
+                self.cb_devices.Clear()
+                self.cb_devices.SetItems(choices)
+                self.cb_devices.SetValue(choices[0])
+                self.lbl_device.Show()
+                self.cb_devices.Show()
+                self.chk_use_gpu.Hide()
+        elif self.cb_backends.GetValue().lower() == "plaidml":
             if HAS_PLAIDML:
+                choices = list(self.plaidml_devices.keys())
+                self.cb_devices.Clear()
+                self.cb_devices.SetItems(choices)
+                self.cb_devices.SetValue(choices[0])
                 self.lbl_device.Show()
                 self.cb_devices.Show()
                 self.chk_use_gpu.Hide()
@@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog):
         self.elapsed_time_timer.Start(1000)
         image = slc.Slice().matrix
         backend = self.cb_backends.GetValue()
-        try:
-            device_id = self.plaidml_devices[self.cb_devices.GetValue()]
-        except (KeyError, AttributeError):
-            device_id = "llvm_cpu.0"
+        if backend.lower() == "pytorch":
+            try:
+                device_id = self.torch_devices[self.cb_devices.GetValue()]
+            except (KeyError, AttributeError):
+                device_id = "cpu"
+        else:
+            try:
+                device_id = self.plaidml_devices[self.cb_devices.GetValue()]
+            except (KeyError, AttributeError):
+                device_id = "llvm_cpu.0"
         apply_wwwl = self.chk_apply_wwwl.GetValue()
         create_new_mask = self.chk_new_mask.GetValue()
         use_gpu = self.chk_use_gpu.GetValue()
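For context, the backend and device plumbing above boils down to a name-to-id mapping that the wx.ComboBox displays. A minimal standalone sketch (not part of the commit) of the same enumeration pattern:

import torch

torch_devices = {}
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        # Map each GPU's display name to its torch device id, e.g.
        # "NVIDIA GeForce RTX 3080" -> "cuda:0".
        # Note: two identical GPU models would collide on the same name key.
        torch_devices[torch.cuda.get_device_name(i)] = f"cuda:{i}"
torch_devices["CPU"] = "cpu"

# The dialog shows the keys; OnSegment() resolves the selected key back
# to a device id such as "cuda:0" or "cpu".
for name, device_id in torch_devices.items():
    print(f"{name} -> {device_id}")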
invesalius/inv_paths.py
@@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co
 USER_INV_DIR = CONF_DIR.joinpath("invesalius")
 USER_PRESET_DIR = USER_INV_DIR.joinpath("presets")
 USER_LOG_DIR = USER_INV_DIR.joinpath("logs")
+USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/")
 USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting")
 TEMP_DIR = tempfile.gettempdir()
@@ -97,6 +98,7 @@ def create_conf_folders():
     USER_INV_DIR.mkdir(parents=True, exist_ok=True)
     USER_PRESET_DIR.mkdir(parents=True, exist_ok=True)
     USER_LOG_DIR.mkdir(parents=True, exist_ok=True)
+    USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True)
     USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True)
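The new weights directory reuses the module's pathlib conventions; a forward-slash relative path in joinpath works on every platform. A small sketch (the home-relative path here is a hypothetical stand-in for the real CONF_DIR logic):

import pathlib

# Hypothetical stand-in for USER_INV_DIR.
user_inv_dir = pathlib.Path.home().joinpath(".config", "invesalius")
user_dl_weights = user_inv_dir.joinpath("deep_learning/weights/")

# mkdir(parents=True, exist_ok=True) is idempotent, so create_conf_folders()
# can safely run on every startup.
user_dl_weights.mkdir(parents=True, exist_ok=True)
print(user_dl_weights.joinpath("brain_mri_t1.pt"))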
invesalius/net/utils.py

@@ -0,0 +1,48 @@
+from urllib.error import HTTPError
+from urllib.request import urlopen, Request
+from urllib.parse import urlparse
+import pathlib
+import tempfile
+import typing
+import hashlib
+import os
+import shutil
+
+def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None):
+    file_size = None
+    total_downloaded = 0
+    if hash is not None:
+        calc_hash = hashlib.sha256()
+    req = Request(url)
+    response = urlopen(req)
+    meta = response.info()
+    if hasattr(meta, "getheaders"):
+        content_length = meta.getheaders("Content-Length")
+    else:
+        content_length = meta.get_all("Content-Length")
+
+    if content_length is not None and len(content_length) > 0:
+        file_size = int(content_length[0])
+    dst.parent.mkdir(parents=True, exist_ok=True)
+    f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent)
+    try:
+        while True:
+            buffer = response.read(8192)
+            if len(buffer) == 0:
+                break
+            total_downloaded += len(buffer)
+            f.write(buffer)
+            if hash is not None:
+                calc_hash.update(buffer)
+            if callback is not None and file_size is not None:
+                callback(100 * total_downloaded / file_size)
+        f.close()
+        if hash is not None:
+            digest = calc_hash.hexdigest()
+            if digest != hash:
+                raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")')
+        shutil.move(f.name, dst)
+    finally:
+        f.close()
+        if os.path.exists(f.name):
+            os.remove(f.name)
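A usage sketch for the helper above; the URL and hash are the ones segment.py passes below, while the progress printer is illustrative:

import pathlib

from invesalius.net.utils import download_url_to_file

def print_progress(percent):
    # Invoked with a 0-100 value per chunk (only when Content-Length is known).
    print(f"\rdownloading: {percent:5.1f}%", end="")

download_url_to_file(
    "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt",
    pathlib.Path("weights.pt"),
    hash="194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22",
    callback=print_progress,
)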
invesalius/segmentation/brain/model.py

@@ -0,0 +1,149 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+SIZE = 48
+
+class Unet3D(nn.Module):
+    # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py
+    # Note: the decoder blocks all reuse name="dec4"; the names only label the
+    # layers inside each block, and they must stay as-is to match the keys in
+    # the released weight file.
+    def __init__(self, in_channels=1, out_channels=1, init_features=8):
+        super().__init__()
+        features = init_features
+
+        self.encoder1 = self._block(
+            in_channels, features=features, padding=2, name="enc1"
+        )
+        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
+
+        self.encoder2 = self._block(
+            features, features=features * 2, padding=2, name="enc2"
+        )
+        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
+
+        self.encoder3 = self._block(
+            features * 2, features=features * 4, padding=2, name="enc3"
+        )
+        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
+
+        self.encoder4 = self._block(
+            features * 4, features=features * 8, padding=2, name="enc4"
+        )
+        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
+
+        self.bottleneck = self._block(
+            features * 8, features=features * 16, padding=2, name="bottleneck"
+        )
+
+        self.upconv4 = nn.ConvTranspose3d(
+            features * 16, features * 8, kernel_size=4, stride=2, padding=1
+        )
+        self.decoder4 = self._block(
+            features * 16, features=features * 8, padding=2, name="dec4"
+        )
+
+        self.upconv3 = nn.ConvTranspose3d(
+            features * 8, features * 4, kernel_size=4, stride=2, padding=1
+        )
+        self.decoder3 = self._block(
+            features * 8, features=features * 4, padding=2, name="dec4"
+        )
+
+        self.upconv2 = nn.ConvTranspose3d(
+            features * 4, features * 2, kernel_size=4, stride=2, padding=1
+        )
+        self.decoder2 = self._block(
+            features * 4, features=features * 2, padding=2, name="dec4"
+        )
+
+        self.upconv1 = nn.ConvTranspose3d(
+            features * 2, features, kernel_size=4, stride=2, padding=1
+        )
+        self.decoder1 = self._block(
+            features * 2, features=features, padding=2, name="dec4"
+        )
+
+        self.conv = nn.Conv3d(
+            in_channels=features, out_channels=out_channels, kernel_size=1
+        )
+
+    def forward(self, img):
+        enc1 = self.encoder1(img)
+        enc2 = self.encoder2(self.pool1(enc1))
+        enc3 = self.encoder3(self.pool2(enc2))
+        enc4 = self.encoder4(self.pool3(enc3))
+
+        bottleneck = self.bottleneck(self.pool4(enc4))
+
+        upconv4 = self.upconv4(bottleneck)
+        dec4 = torch.cat((upconv4, enc4), dim=1)
+        dec4 = self.decoder4(dec4)
+
+        upconv3 = self.upconv3(dec4)
+        dec3 = torch.cat((upconv3, enc3), dim=1)
+        dec3 = self.decoder3(dec3)
+
+        upconv2 = self.upconv2(dec3)
+        dec2 = torch.cat((upconv2, enc2), dim=1)
+        dec2 = self.decoder2(dec2)
+
+        upconv1 = self.upconv1(dec2)
+        dec1 = torch.cat((upconv1, enc1), dim=1)
+        dec1 = self.decoder1(dec1)
+
+        conv = self.conv(dec1)
+
+        sigmoid = torch.sigmoid(conv)
+
+        return sigmoid
+
+    def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"):
+        return nn.Sequential(
+            OrderedDict(
+                (
+                    (
+                        f"{name}_conv1",
+                        nn.Conv3d(
+                            in_channels=in_channels,
+                            out_channels=features,
+                            kernel_size=kernel_size,
+                            padding=padding,
+                            bias=True,
+                        ),
+                    ),
+                    (f"{name}_norm1", nn.BatchNorm3d(num_features=features)),
+                    (f"{name}_relu1", nn.ReLU(inplace=True)),
+                    (
+                        f"{name}_conv2",
+                        nn.Conv3d(
+                            in_channels=features,
+                            out_channels=features,
+                            kernel_size=kernel_size,
+                            padding=padding,
+                            bias=True,
+                        ),
+                    ),
+                    (f"{name}_norm2", nn.BatchNorm3d(num_features=features)),
+                    (f"{name}_relu2", nn.ReLU(inplace=True)),
+                )
+            )
+        )
+
+
+def main():
+    import torchviz
+    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model = Unet3D()
+    model.to(dev)
+    model.eval()
+    print(next(model.parameters()).is_cuda)  # True when running on a GPU
+    # nn.Conv3d expects (batch, channels, depth, height, width).
+    img = torch.randn(1, 1, SIZE, SIZE, SIZE).to(dev)
+    out = model(img)
+    dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
+    dot.render("unet", format="png")
+    torch.save(model, "model.pth")
+    print(dot)
+
+
+if __name__ == "__main__":
+    main()
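Why SIZE = 48 fits this architecture: the four 2x2x2 max-poolings halve each spatial dimension (48 -> 24 -> 12 -> 6 -> 3), the stride-2 transposed convolutions double them back, and the padding-2, kernel-5 convolutions preserve size, so the output volume matches the input. A quick shape check (a sketch, assuming the file above is saved as model.py):

import torch
from model import Unet3D, SIZE

model = Unet3D()
model.eval()
with torch.no_grad():
    # (batch, channels, depth, height, width), one grayscale channel.
    x = torch.randn(1, 1, SIZE, SIZE, SIZE)
    y = model(x)
print(x.shape, "->", y.shape)  # torch.Size([1, 1, 48, 48, 48]) -> same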
invesalius/segmentation/brain/segment.py
@@ -13,6 +13,7 @@ import invesalius.data.slice_ as slc
 from invesalius import inv_paths
 from invesalius.data import imagedata_utils
 from invesalius.utils import new_name_by_pattern
+from invesalius.net.utils import download_url_to_file

 from . import utils

@@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE):
         0 : ez - iz, 0 : ey - iy, 0 : ex - ix
     ]

+
+def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE):
+    import torch
+    with torch.no_grad():
+        (iz, ez), (iy, ey), (ix, ex) = patch
+        sub_mask = nn_model(
+            torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device)
+        ).cpu().numpy()
+        return sub_mask.reshape(patch_size, patch_size, patch_size)[
+            0 : ez - iz, 0 : ey - iy, 0 : ex - ix
+        ]
+

 def brain_segment(image, probability_array, comm_array):
     import keras
@@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array):
     comm_array[0] = np.Inf


+def download_callback(comm_array):
+    def _download_callback(value):
+        comm_array[0] = value
+    return _download_callback
+
+def brain_segment_torch(image, device_id, probability_array, comm_array):
+    import torch
+    from .model import Unet3D
+    device = torch.device(device_id)
+    state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt")
+    if not state_dict_file.exists():
+        download_url_to_file(
+            "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt",
+            state_dict_file,
+            "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22",
+            download_callback(comm_array),
+        )
+    state_dict = torch.load(str(state_dict_file))
+    model = Unet3D()
+    model.load_state_dict(state_dict["model_state_dict"])
+    model.to(device)
+    model.eval()
+
+    image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32)
+    sums = np.zeros_like(image)
+    # segmenting by patches
+    for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP):
+        comm_array[0] = completion
+        (iz, ez), (iy, ey), (ix, ex) = patch
+        sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE)
+        probability_array[iz:ez, iy:ey, ix:ex] += sub_mask
+        sums[iz:ez, iy:ey, ix:ex] += 1
+
+    probability_array /= sums
+    comm_array[0] = np.Inf
+
+
 ctx = multiprocessing.get_context('spawn')
 class SegmentProcess(ctx.Process):
     def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127):
@@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process):
             mode="r",
         )

-        print(image.min(), image.max())
-        if self.apply_segment_threshold:
+        if self.apply_wwwl:
             print("Applying window level")
             image = get_LUT_value(image, self.window_width, self.window_level)

@@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process):
             self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+"
        )

-        utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
-        brain_segment(image, probability_array, comm_array)
+        if self.backend.lower() == "pytorch":
+            brain_segment_torch(image, self.device_id, probability_array, comm_array)
+        else:
+            utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
+            brain_segment(image, probability_array, comm_array)

     @property
     def exception(self):
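brain_segment_torch averages overlapping patch predictions: probability_array accumulates per-patch probabilities, sums counts how many patches covered each voxel, and the final division turns the accumulated sums into means. A NumPy-only sketch of that accumulation (gen_patches is simplified here to a fixed stride; the real generator lives alongside this module):

import numpy as np

SIZE = 48           # patch edge length, as in the module above
STRIDE = SIZE // 2  # simplified; the real gen_patches derives its step from OVERLAP

def fake_predict(sub_image):
    # Stand-in for predict_patch_torch: any per-voxel probability map works.
    return np.full_like(sub_image, 0.5)

image = np.random.rand(96, 96, 96).astype(np.float32)
probability_array = np.zeros_like(image)
sums = np.zeros_like(image)

for iz in range(0, image.shape[0] - SIZE + 1, STRIDE):
    for iy in range(0, image.shape[1] - SIZE + 1, STRIDE):
        for ix in range(0, image.shape[2] - SIZE + 1, STRIDE):
            patch = image[iz:iz + SIZE, iy:iy + SIZE, ix:ix + SIZE]
            probability_array[iz:iz + SIZE, iy:iy + SIZE, ix:ix + SIZE] += fake_predict(patch)
            sums[iz:iz + SIZE, iy:iy + SIZE, ix:ix + SIZE] += 1

# Voxels covered by several patches end up with the mean of their predictions.
probability_array /= np.maximum(sums, 1)
print(probability_array.min(), probability_array.max())  # 0.5 0.5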
requirements.txt