Commit a406f61e7cce53232cda3932c8981e993deb61fb
Committed by
GitHub
1 parent
9460b9cd
Exists in
master
Add pytorch backend to brain segmentation (#365)
* added pytorch to requirements.txt and showing in GUI * segmenting using pytorch * Remove debug prints * Added an function to download a file and save it localy * Downloading the weight using the developed function to download file
Showing
6 changed files
with
308 additions
and
13 deletions
Show diff stats
invesalius/gui/brain_seg_dialog.py
... | ... | @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) |
22 | 22 | HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) |
23 | 23 | PLAIDML_DEVICES = {} |
24 | 24 | |
25 | +try: | |
26 | + import torch | |
27 | + HAS_TORCH = True | |
28 | +except ImportError: | |
29 | + HAS_TORCH = False | |
30 | + | |
31 | +if HAS_TORCH: | |
32 | + TORCH_DEVICES = {} | |
33 | + if torch.cuda.is_available(): | |
34 | + for i in range(torch.cuda.device_count()): | |
35 | + name = torch.cuda.get_device_name() | |
36 | + device_id = f'cuda:{i}' | |
37 | + TORCH_DEVICES[name] = device_id | |
38 | + TORCH_DEVICES['CPU'] = 'cpu' | |
39 | + | |
40 | + | |
25 | 41 | |
26 | 42 | if HAS_PLAIDML: |
27 | 43 | with multiprocessing.Pool(1) as p: |
... | ... | @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): |
43 | 59 | style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, |
44 | 60 | ) |
45 | 61 | backends = [] |
62 | + if HAS_TORCH: | |
63 | + backends.append("Pytorch") | |
46 | 64 | if HAS_PLAIDML: |
47 | 65 | backends.append("PlaidML") |
48 | 66 | if HAS_THEANO: |
49 | 67 | backends.append("Theano") |
50 | 68 | # self.segmenter = segment.BrainSegmenter() |
51 | 69 | # self.pg_dialog = None |
70 | + self.torch_devices = TORCH_DEVICES | |
52 | 71 | self.plaidml_devices = PLAIDML_DEVICES |
53 | 72 | |
54 | 73 | self.ps = None |
... | ... | @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): |
65 | 84 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
66 | 85 | self.cb_backends.SetMinClientSize((w, -1)) |
67 | 86 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
68 | - if HAS_PLAIDML: | |
87 | + if HAS_TORCH or HAS_PLAIDML: | |
88 | + if HAS_TORCH: | |
89 | + choices = list(self.torch_devices.keys()) | |
90 | + value = choices[0] | |
91 | + else: | |
92 | + choices = list(self.plaidml_devices.keys()) | |
93 | + value = choices[0] | |
69 | 94 | self.lbl_device = wx.StaticText(self, -1, _("Device")) |
70 | 95 | self.cb_devices = wx.ComboBox( |
71 | 96 | self, |
72 | 97 | wx.ID_ANY, |
73 | - choices=list(self.plaidml_devices.keys()), | |
74 | - value=list(self.plaidml_devices.keys())[0], | |
98 | + choices=choices, | |
99 | + value=value, | |
75 | 100 | style=wx.CB_DROPDOWN | wx.CB_READONLY, |
76 | 101 | ) |
77 | 102 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
... | ... | @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): |
109 | 134 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
110 | 135 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
111 | 136 | sizer_devices = wx.BoxSizer(wx.HORIZONTAL) |
112 | - if HAS_PLAIDML: | |
137 | + if HAS_TORCH or HAS_PLAIDML: | |
113 | 138 | sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) |
114 | 139 | sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) |
115 | 140 | main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) |
... | ... | @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): |
177 | 202 | return width, height |
178 | 203 | |
179 | 204 | def OnSetBackend(self, evt=None): |
180 | - if self.cb_backends.GetValue().lower() == "plaidml": | |
205 | + if self.cb_backends.GetValue().lower() == "pytorch": | |
206 | + if HAS_TORCH: | |
207 | + choices = list(self.torch_devices.keys()) | |
208 | + self.cb_devices.Clear() | |
209 | + self.cb_devices.SetItems(choices) | |
210 | + self.cb_devices.SetValue(choices[0]) | |
211 | + self.lbl_device.Show() | |
212 | + self.cb_devices.Show() | |
213 | + self.chk_use_gpu.Hide() | |
214 | + elif self.cb_backends.GetValue().lower() == "plaidml": | |
181 | 215 | if HAS_PLAIDML: |
216 | + choices = list(self.plaidml_devices.keys()) | |
217 | + self.cb_devices.Clear() | |
218 | + self.cb_devices.SetItems(choices) | |
219 | + self.cb_devices.SetValue(choices[0]) | |
182 | 220 | self.lbl_device.Show() |
183 | 221 | self.cb_devices.Show() |
184 | 222 | self.chk_use_gpu.Hide() |
... | ... | @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): |
216 | 254 | self.elapsed_time_timer.Start(1000) |
217 | 255 | image = slc.Slice().matrix |
218 | 256 | backend = self.cb_backends.GetValue() |
219 | - try: | |
220 | - device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
221 | - except (KeyError, AttributeError): | |
222 | - device_id = "llvm_cpu.0" | |
257 | + if backend.lower() == "pytorch": | |
258 | + try: | |
259 | + device_id = self.torch_devices[self.cb_devices.GetValue()] | |
260 | + except (KeyError, AttributeError): | |
261 | + device_id = "cpu" | |
262 | + else: | |
263 | + try: | |
264 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
265 | + except (KeyError, AttributeError): | |
266 | + device_id = "llvm_cpu.0" | |
223 | 267 | apply_wwwl = self.chk_apply_wwwl.GetValue() |
224 | 268 | create_new_mask = self.chk_new_mask.GetValue() |
225 | 269 | use_gpu = self.chk_use_gpu.GetValue() | ... | ... |
invesalius/inv_paths.py
... | ... | @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co |
27 | 27 | USER_INV_DIR = CONF_DIR.joinpath("invesalius") |
28 | 28 | USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") |
29 | 29 | USER_LOG_DIR = USER_INV_DIR.joinpath("logs") |
30 | +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/") | |
30 | 31 | USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") |
31 | 32 | TEMP_DIR = tempfile.gettempdir() |
32 | 33 | |
... | ... | @@ -97,6 +98,7 @@ def create_conf_folders(): |
97 | 98 | USER_INV_DIR.mkdir(parents=True, exist_ok=True) |
98 | 99 | USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) |
99 | 100 | USER_LOG_DIR.mkdir(parents=True, exist_ok=True) |
101 | + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True) | |
100 | 102 | USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) |
101 | 103 | |
102 | 104 | ... | ... |
... | ... | @@ -0,0 +1,48 @@ |
1 | +from urllib.error import HTTPError | |
2 | +from urllib.request import urlopen, Request | |
3 | +from urllib.parse import urlparse | |
4 | +import pathlib | |
5 | +import tempfile | |
6 | +import typing | |
7 | +import hashlib | |
8 | +import os | |
9 | +import shutil | |
10 | + | |
11 | +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None): | |
12 | + file_size = None | |
13 | + total_downloaded = 0 | |
14 | + if hash is not None: | |
15 | + calc_hash = hashlib.sha256() | |
16 | + req = Request(url) | |
17 | + response = urlopen(req) | |
18 | + meta = response.info() | |
19 | + if hasattr(meta, "getheaders"): | |
20 | + content_length = meta.getheaders("Content-Length") | |
21 | + else: | |
22 | + content_length = meta.get_all("Content-Length") | |
23 | + | |
24 | + if content_length is not None and len(content_length) > 0: | |
25 | + file_size = int(content_length[0]) | |
26 | + dst.parent.mkdir(parents=True, exist_ok=True) | |
27 | + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent) | |
28 | + try: | |
29 | + while True: | |
30 | + buffer = response.read(8192) | |
31 | + if len(buffer) == 0: | |
32 | + break | |
33 | + total_downloaded += len(buffer) | |
34 | + f.write(buffer) | |
35 | + if hash: | |
36 | + calc_hash.update(buffer) | |
37 | + if callback is not None: | |
38 | + callback(100 * total_downloaded/file_size) | |
39 | + f.close() | |
40 | + if hash is not None: | |
41 | + digest = calc_hash.hexdigest() | |
42 | + if digest != hash: | |
43 | + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")') | |
44 | + shutil.move(f.name, dst) | |
45 | + finally: | |
46 | + f.close() | |
47 | + if os.path.exists(f.name): | |
48 | + os.remove(f.name) | ... | ... |
... | ... | @@ -0,0 +1,149 @@ |
1 | +from collections import OrderedDict | |
2 | + | |
3 | +import torch | |
4 | +import torch.nn as nn | |
5 | + | |
6 | +SIZE = 48 | |
7 | + | |
8 | +class Unet3D(nn.Module): | |
9 | + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py | |
10 | + def __init__(self, in_channels=1, out_channels=1, init_features=8): | |
11 | + super().__init__() | |
12 | + features = init_features | |
13 | + | |
14 | + self.encoder1 = self._block( | |
15 | + in_channels, features=features, padding=2, name="enc1" | |
16 | + ) | |
17 | + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) | |
18 | + | |
19 | + self.encoder2 = self._block( | |
20 | + features, features=features * 2, padding=2, name="enc2" | |
21 | + ) | |
22 | + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) | |
23 | + | |
24 | + self.encoder3 = self._block( | |
25 | + features * 2, features=features * 4, padding=2, name="enc3" | |
26 | + ) | |
27 | + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) | |
28 | + | |
29 | + self.encoder4 = self._block( | |
30 | + features * 4, features=features * 8, padding=2, name="enc4" | |
31 | + ) | |
32 | + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) | |
33 | + | |
34 | + self.bottleneck = self._block( | |
35 | + features * 8, features=features * 16, padding=2, name="bottleneck" | |
36 | + ) | |
37 | + | |
38 | + self.upconv4 = nn.ConvTranspose3d( | |
39 | + features * 16, features * 8, kernel_size=4, stride=2, padding=1 | |
40 | + ) | |
41 | + self.decoder4 = self._block( | |
42 | + features * 16, features=features * 8, padding=2, name="dec4" | |
43 | + ) | |
44 | + | |
45 | + self.upconv3 = nn.ConvTranspose3d( | |
46 | + features * 8, features * 4, kernel_size=4, stride=2, padding=1 | |
47 | + ) | |
48 | + self.decoder3 = self._block( | |
49 | + features * 8, features=features * 4, padding=2, name="dec4" | |
50 | + ) | |
51 | + | |
52 | + self.upconv2 = nn.ConvTranspose3d( | |
53 | + features * 4, features * 2, kernel_size=4, stride=2, padding=1 | |
54 | + ) | |
55 | + self.decoder2 = self._block( | |
56 | + features * 4, features=features * 2, padding=2, name="dec4" | |
57 | + ) | |
58 | + | |
59 | + self.upconv1 = nn.ConvTranspose3d( | |
60 | + features * 2, features, kernel_size=4, stride=2, padding=1 | |
61 | + ) | |
62 | + self.decoder1 = self._block( | |
63 | + features * 2, features=features, padding=2, name="dec4" | |
64 | + ) | |
65 | + | |
66 | + self.conv = nn.Conv3d( | |
67 | + in_channels=features, out_channels=out_channels, kernel_size=1 | |
68 | + ) | |
69 | + | |
70 | + def forward(self, img): | |
71 | + enc1 = self.encoder1(img) | |
72 | + enc2 = self.encoder2(self.pool1(enc1)) | |
73 | + enc3 = self.encoder3(self.pool2(enc2)) | |
74 | + enc4 = self.encoder4(self.pool3(enc3)) | |
75 | + | |
76 | + bottleneck = self.bottleneck(self.pool4(enc4)) | |
77 | + | |
78 | + upconv4 = self.upconv4(bottleneck) | |
79 | + dec4 = torch.cat((upconv4, enc4), dim=1) | |
80 | + dec4 = self.decoder4(dec4) | |
81 | + | |
82 | + upconv3 = self.upconv3(dec4) | |
83 | + dec3 = torch.cat((upconv3, enc3), dim=1) | |
84 | + dec3 = self.decoder3(dec3) | |
85 | + | |
86 | + upconv2 = self.upconv2(dec3) | |
87 | + dec2 = torch.cat((upconv2, enc2), dim=1) | |
88 | + dec2 = self.decoder2(dec2) | |
89 | + | |
90 | + upconv1 = self.upconv1(dec2) | |
91 | + dec1 = torch.cat((upconv1, enc1), dim=1) | |
92 | + dec1 = self.decoder1(dec1) | |
93 | + | |
94 | + conv = self.conv(dec1) | |
95 | + | |
96 | + sigmoid = torch.sigmoid(conv) | |
97 | + | |
98 | + return sigmoid | |
99 | + | |
100 | + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"): | |
101 | + return nn.Sequential( | |
102 | + OrderedDict( | |
103 | + ( | |
104 | + ( | |
105 | + f"{name}_conv1", | |
106 | + nn.Conv3d( | |
107 | + in_channels=in_channels, | |
108 | + out_channels=features, | |
109 | + kernel_size=kernel_size, | |
110 | + padding=padding, | |
111 | + bias=True, | |
112 | + ), | |
113 | + ), | |
114 | + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)), | |
115 | + (f"{name}_relu1", nn.ReLU(inplace=True)), | |
116 | + ( | |
117 | + f"{name}_conv2", | |
118 | + nn.Conv3d( | |
119 | + in_channels=features, | |
120 | + out_channels=features, | |
121 | + kernel_size=kernel_size, | |
122 | + padding=padding, | |
123 | + bias=True, | |
124 | + ), | |
125 | + ), | |
126 | + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)), | |
127 | + (f"{name}_relu2", nn.ReLU(inplace=True)), | |
128 | + ) | |
129 | + ) | |
130 | + ) | |
131 | + | |
132 | + | |
133 | +def main(): | |
134 | + import torchviz | |
135 | + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
136 | + model = Unet3D() | |
137 | + model.to(dev) | |
138 | + model.eval() | |
139 | + print(next(model.parameters()).is_cuda) # True | |
140 | + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev) | |
141 | + out = model(img) | |
142 | + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) | |
143 | + dot.render("unet", format="png") | |
144 | + torch.save(model, "model.pth") | |
145 | + print(dot) | |
146 | + | |
147 | + | |
148 | +if __name__ == "__main__": | |
149 | + main() | ... | ... |
invesalius/segmentation/brain/segment.py
... | ... | @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc |
13 | 13 | from invesalius import inv_paths |
14 | 14 | from invesalius.data import imagedata_utils |
15 | 15 | from invesalius.utils import new_name_by_pattern |
16 | +from invesalius.net.utils import download_url_to_file | |
17 | +from invesalius import inv_paths | |
16 | 18 | |
17 | 19 | from . import utils |
18 | 20 | |
... | ... | @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): |
64 | 66 | 0 : ez - iz, 0 : ey - iy, 0 : ex - ix |
65 | 67 | ] |
66 | 68 | |
69 | +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE): | |
70 | + import torch | |
71 | + with torch.no_grad(): | |
72 | + (iz, ez), (iy, ey), (ix, ex) = patch | |
73 | + sub_mask = nn_model( | |
74 | + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device) | |
75 | + ).cpu().numpy() | |
76 | + return sub_mask.reshape(patch_size, patch_size, patch_size)[ | |
77 | + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix | |
78 | + ] | |
79 | + | |
67 | 80 | |
68 | 81 | def brain_segment(image, probability_array, comm_array): |
69 | 82 | import keras |
... | ... | @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): |
89 | 102 | comm_array[0] = np.Inf |
90 | 103 | |
91 | 104 | |
105 | +def download_callback(comm_array): | |
106 | + def _download_callback(value): | |
107 | + comm_array[0] = value | |
108 | + return _download_callback | |
109 | + | |
110 | +def brain_segment_torch(image, device_id, probability_array, comm_array): | |
111 | + import torch | |
112 | + from .model import Unet3D | |
113 | + device = torch.device(device_id) | |
114 | + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | |
115 | + if not state_dict_file.exists(): | |
116 | + download_url_to_file( | |
117 | + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", | |
118 | + state_dict_file, | |
119 | + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", | |
120 | + download_callback(comm_array) | |
121 | + ) | |
122 | + state_dict = torch.load(str(state_dict_file)) | |
123 | + model = Unet3D() | |
124 | + model.load_state_dict(state_dict["model_state_dict"]) | |
125 | + model.to(device) | |
126 | + model.eval() | |
127 | + | |
128 | + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32) | |
129 | + sums = np.zeros_like(image) | |
130 | + # segmenting by patches | |
131 | + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP): | |
132 | + comm_array[0] = completion | |
133 | + (iz, ez), (iy, ey), (ix, ex) = patch | |
134 | + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE) | |
135 | + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask | |
136 | + sums[iz:ez, iy:ey, ix:ex] += 1 | |
137 | + | |
138 | + probability_array /= sums | |
139 | + comm_array[0] = np.Inf | |
140 | + | |
92 | 141 | ctx = multiprocessing.get_context('spawn') |
93 | 142 | class SegmentProcess(ctx.Process): |
94 | 143 | def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): |
... | ... | @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): |
138 | 187 | mode="r", |
139 | 188 | ) |
140 | 189 | |
141 | - print(image.min(), image.max()) | |
142 | - if self.apply_segment_threshold: | |
190 | + if self.apply_wwwl: | |
143 | 191 | print("Applying window level") |
144 | 192 | image = get_LUT_value(image, self.window_width, self.window_level) |
145 | 193 | |
... | ... | @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): |
153 | 201 | self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" |
154 | 202 | ) |
155 | 203 | |
156 | - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | |
157 | - brain_segment(image, probability_array, comm_array) | |
204 | + if self.backend.lower() == "pytorch": | |
205 | + brain_segment_torch(image, self.device_id, probability_array, comm_array) | |
206 | + else: | |
207 | + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | |
208 | + brain_segment(image, probability_array, comm_array) | |
158 | 209 | |
159 | 210 | @property |
160 | 211 | def exception(self): | ... | ... |
requirements.txt