Commit a406f61e7cce53232cda3932c8981e993deb61fb
Committed by
GitHub
1 parent
9460b9cd
Exists in
master
Add pytorch backend to brain segmentation (#365)
* added pytorch to requirements.txt and showing in GUI * segmenting using pytorch * Remove debug prints * Added an function to download a file and save it localy * Downloading the weight using the developed function to download file
Showing
6 changed files
with
308 additions
and
13 deletions
Show diff stats
invesalius/gui/brain_seg_dialog.py
| ... | ... | @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) |
| 22 | 22 | HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) |
| 23 | 23 | PLAIDML_DEVICES = {} |
| 24 | 24 | |
| 25 | +try: | |
| 26 | + import torch | |
| 27 | + HAS_TORCH = True | |
| 28 | +except ImportError: | |
| 29 | + HAS_TORCH = False | |
| 30 | + | |
| 31 | +if HAS_TORCH: | |
| 32 | + TORCH_DEVICES = {} | |
| 33 | + if torch.cuda.is_available(): | |
| 34 | + for i in range(torch.cuda.device_count()): | |
| 35 | + name = torch.cuda.get_device_name() | |
| 36 | + device_id = f'cuda:{i}' | |
| 37 | + TORCH_DEVICES[name] = device_id | |
| 38 | + TORCH_DEVICES['CPU'] = 'cpu' | |
| 39 | + | |
| 40 | + | |
| 25 | 41 | |
| 26 | 42 | if HAS_PLAIDML: |
| 27 | 43 | with multiprocessing.Pool(1) as p: |
| ... | ... | @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): |
| 43 | 59 | style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, |
| 44 | 60 | ) |
| 45 | 61 | backends = [] |
| 62 | + if HAS_TORCH: | |
| 63 | + backends.append("Pytorch") | |
| 46 | 64 | if HAS_PLAIDML: |
| 47 | 65 | backends.append("PlaidML") |
| 48 | 66 | if HAS_THEANO: |
| 49 | 67 | backends.append("Theano") |
| 50 | 68 | # self.segmenter = segment.BrainSegmenter() |
| 51 | 69 | # self.pg_dialog = None |
| 70 | + self.torch_devices = TORCH_DEVICES | |
| 52 | 71 | self.plaidml_devices = PLAIDML_DEVICES |
| 53 | 72 | |
| 54 | 73 | self.ps = None |
| ... | ... | @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): |
| 65 | 84 | w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) |
| 66 | 85 | self.cb_backends.SetMinClientSize((w, -1)) |
| 67 | 86 | self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) |
| 68 | - if HAS_PLAIDML: | |
| 87 | + if HAS_TORCH or HAS_PLAIDML: | |
| 88 | + if HAS_TORCH: | |
| 89 | + choices = list(self.torch_devices.keys()) | |
| 90 | + value = choices[0] | |
| 91 | + else: | |
| 92 | + choices = list(self.plaidml_devices.keys()) | |
| 93 | + value = choices[0] | |
| 69 | 94 | self.lbl_device = wx.StaticText(self, -1, _("Device")) |
| 70 | 95 | self.cb_devices = wx.ComboBox( |
| 71 | 96 | self, |
| 72 | 97 | wx.ID_ANY, |
| 73 | - choices=list(self.plaidml_devices.keys()), | |
| 74 | - value=list(self.plaidml_devices.keys())[0], | |
| 98 | + choices=choices, | |
| 99 | + value=value, | |
| 75 | 100 | style=wx.CB_DROPDOWN | wx.CB_READONLY, |
| 76 | 101 | ) |
| 77 | 102 | self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) |
| ... | ... | @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): |
| 109 | 134 | main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) |
| 110 | 135 | main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) |
| 111 | 136 | sizer_devices = wx.BoxSizer(wx.HORIZONTAL) |
| 112 | - if HAS_PLAIDML: | |
| 137 | + if HAS_TORCH or HAS_PLAIDML: | |
| 113 | 138 | sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) |
| 114 | 139 | sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) |
| 115 | 140 | main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) |
| ... | ... | @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): |
| 177 | 202 | return width, height |
| 178 | 203 | |
| 179 | 204 | def OnSetBackend(self, evt=None): |
| 180 | - if self.cb_backends.GetValue().lower() == "plaidml": | |
| 205 | + if self.cb_backends.GetValue().lower() == "pytorch": | |
| 206 | + if HAS_TORCH: | |
| 207 | + choices = list(self.torch_devices.keys()) | |
| 208 | + self.cb_devices.Clear() | |
| 209 | + self.cb_devices.SetItems(choices) | |
| 210 | + self.cb_devices.SetValue(choices[0]) | |
| 211 | + self.lbl_device.Show() | |
| 212 | + self.cb_devices.Show() | |
| 213 | + self.chk_use_gpu.Hide() | |
| 214 | + elif self.cb_backends.GetValue().lower() == "plaidml": | |
| 181 | 215 | if HAS_PLAIDML: |
| 216 | + choices = list(self.plaidml_devices.keys()) | |
| 217 | + self.cb_devices.Clear() | |
| 218 | + self.cb_devices.SetItems(choices) | |
| 219 | + self.cb_devices.SetValue(choices[0]) | |
| 182 | 220 | self.lbl_device.Show() |
| 183 | 221 | self.cb_devices.Show() |
| 184 | 222 | self.chk_use_gpu.Hide() |
| ... | ... | @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): |
| 216 | 254 | self.elapsed_time_timer.Start(1000) |
| 217 | 255 | image = slc.Slice().matrix |
| 218 | 256 | backend = self.cb_backends.GetValue() |
| 219 | - try: | |
| 220 | - device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
| 221 | - except (KeyError, AttributeError): | |
| 222 | - device_id = "llvm_cpu.0" | |
| 257 | + if backend.lower() == "pytorch": | |
| 258 | + try: | |
| 259 | + device_id = self.torch_devices[self.cb_devices.GetValue()] | |
| 260 | + except (KeyError, AttributeError): | |
| 261 | + device_id = "cpu" | |
| 262 | + else: | |
| 263 | + try: | |
| 264 | + device_id = self.plaidml_devices[self.cb_devices.GetValue()] | |
| 265 | + except (KeyError, AttributeError): | |
| 266 | + device_id = "llvm_cpu.0" | |
| 223 | 267 | apply_wwwl = self.chk_apply_wwwl.GetValue() |
| 224 | 268 | create_new_mask = self.chk_new_mask.GetValue() |
| 225 | 269 | use_gpu = self.chk_use_gpu.GetValue() | ... | ... |
invesalius/inv_paths.py
| ... | ... | @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co |
| 27 | 27 | USER_INV_DIR = CONF_DIR.joinpath("invesalius") |
| 28 | 28 | USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") |
| 29 | 29 | USER_LOG_DIR = USER_INV_DIR.joinpath("logs") |
| 30 | +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/") | |
| 30 | 31 | USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") |
| 31 | 32 | TEMP_DIR = tempfile.gettempdir() |
| 32 | 33 | |
| ... | ... | @@ -97,6 +98,7 @@ def create_conf_folders(): |
| 97 | 98 | USER_INV_DIR.mkdir(parents=True, exist_ok=True) |
| 98 | 99 | USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) |
| 99 | 100 | USER_LOG_DIR.mkdir(parents=True, exist_ok=True) |
| 101 | + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True) | |
| 100 | 102 | USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) |
| 101 | 103 | |
| 102 | 104 | ... | ... |
| ... | ... | @@ -0,0 +1,48 @@ |
| 1 | +from urllib.error import HTTPError | |
| 2 | +from urllib.request import urlopen, Request | |
| 3 | +from urllib.parse import urlparse | |
| 4 | +import pathlib | |
| 5 | +import tempfile | |
| 6 | +import typing | |
| 7 | +import hashlib | |
| 8 | +import os | |
| 9 | +import shutil | |
| 10 | + | |
| 11 | +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None): | |
| 12 | + file_size = None | |
| 13 | + total_downloaded = 0 | |
| 14 | + if hash is not None: | |
| 15 | + calc_hash = hashlib.sha256() | |
| 16 | + req = Request(url) | |
| 17 | + response = urlopen(req) | |
| 18 | + meta = response.info() | |
| 19 | + if hasattr(meta, "getheaders"): | |
| 20 | + content_length = meta.getheaders("Content-Length") | |
| 21 | + else: | |
| 22 | + content_length = meta.get_all("Content-Length") | |
| 23 | + | |
| 24 | + if content_length is not None and len(content_length) > 0: | |
| 25 | + file_size = int(content_length[0]) | |
| 26 | + dst.parent.mkdir(parents=True, exist_ok=True) | |
| 27 | + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent) | |
| 28 | + try: | |
| 29 | + while True: | |
| 30 | + buffer = response.read(8192) | |
| 31 | + if len(buffer) == 0: | |
| 32 | + break | |
| 33 | + total_downloaded += len(buffer) | |
| 34 | + f.write(buffer) | |
| 35 | + if hash: | |
| 36 | + calc_hash.update(buffer) | |
| 37 | + if callback is not None: | |
| 38 | + callback(100 * total_downloaded/file_size) | |
| 39 | + f.close() | |
| 40 | + if hash is not None: | |
| 41 | + digest = calc_hash.hexdigest() | |
| 42 | + if digest != hash: | |
| 43 | + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")') | |
| 44 | + shutil.move(f.name, dst) | |
| 45 | + finally: | |
| 46 | + f.close() | |
| 47 | + if os.path.exists(f.name): | |
| 48 | + os.remove(f.name) | ... | ... |
| ... | ... | @@ -0,0 +1,149 @@ |
| 1 | +from collections import OrderedDict | |
| 2 | + | |
| 3 | +import torch | |
| 4 | +import torch.nn as nn | |
| 5 | + | |
| 6 | +SIZE = 48 | |
| 7 | + | |
| 8 | +class Unet3D(nn.Module): | |
| 9 | + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py | |
| 10 | + def __init__(self, in_channels=1, out_channels=1, init_features=8): | |
| 11 | + super().__init__() | |
| 12 | + features = init_features | |
| 13 | + | |
| 14 | + self.encoder1 = self._block( | |
| 15 | + in_channels, features=features, padding=2, name="enc1" | |
| 16 | + ) | |
| 17 | + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) | |
| 18 | + | |
| 19 | + self.encoder2 = self._block( | |
| 20 | + features, features=features * 2, padding=2, name="enc2" | |
| 21 | + ) | |
| 22 | + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) | |
| 23 | + | |
| 24 | + self.encoder3 = self._block( | |
| 25 | + features * 2, features=features * 4, padding=2, name="enc3" | |
| 26 | + ) | |
| 27 | + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) | |
| 28 | + | |
| 29 | + self.encoder4 = self._block( | |
| 30 | + features * 4, features=features * 8, padding=2, name="enc4" | |
| 31 | + ) | |
| 32 | + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) | |
| 33 | + | |
| 34 | + self.bottleneck = self._block( | |
| 35 | + features * 8, features=features * 16, padding=2, name="bottleneck" | |
| 36 | + ) | |
| 37 | + | |
| 38 | + self.upconv4 = nn.ConvTranspose3d( | |
| 39 | + features * 16, features * 8, kernel_size=4, stride=2, padding=1 | |
| 40 | + ) | |
| 41 | + self.decoder4 = self._block( | |
| 42 | + features * 16, features=features * 8, padding=2, name="dec4" | |
| 43 | + ) | |
| 44 | + | |
| 45 | + self.upconv3 = nn.ConvTranspose3d( | |
| 46 | + features * 8, features * 4, kernel_size=4, stride=2, padding=1 | |
| 47 | + ) | |
| 48 | + self.decoder3 = self._block( | |
| 49 | + features * 8, features=features * 4, padding=2, name="dec4" | |
| 50 | + ) | |
| 51 | + | |
| 52 | + self.upconv2 = nn.ConvTranspose3d( | |
| 53 | + features * 4, features * 2, kernel_size=4, stride=2, padding=1 | |
| 54 | + ) | |
| 55 | + self.decoder2 = self._block( | |
| 56 | + features * 4, features=features * 2, padding=2, name="dec4" | |
| 57 | + ) | |
| 58 | + | |
| 59 | + self.upconv1 = nn.ConvTranspose3d( | |
| 60 | + features * 2, features, kernel_size=4, stride=2, padding=1 | |
| 61 | + ) | |
| 62 | + self.decoder1 = self._block( | |
| 63 | + features * 2, features=features, padding=2, name="dec4" | |
| 64 | + ) | |
| 65 | + | |
| 66 | + self.conv = nn.Conv3d( | |
| 67 | + in_channels=features, out_channels=out_channels, kernel_size=1 | |
| 68 | + ) | |
| 69 | + | |
| 70 | + def forward(self, img): | |
| 71 | + enc1 = self.encoder1(img) | |
| 72 | + enc2 = self.encoder2(self.pool1(enc1)) | |
| 73 | + enc3 = self.encoder3(self.pool2(enc2)) | |
| 74 | + enc4 = self.encoder4(self.pool3(enc3)) | |
| 75 | + | |
| 76 | + bottleneck = self.bottleneck(self.pool4(enc4)) | |
| 77 | + | |
| 78 | + upconv4 = self.upconv4(bottleneck) | |
| 79 | + dec4 = torch.cat((upconv4, enc4), dim=1) | |
| 80 | + dec4 = self.decoder4(dec4) | |
| 81 | + | |
| 82 | + upconv3 = self.upconv3(dec4) | |
| 83 | + dec3 = torch.cat((upconv3, enc3), dim=1) | |
| 84 | + dec3 = self.decoder3(dec3) | |
| 85 | + | |
| 86 | + upconv2 = self.upconv2(dec3) | |
| 87 | + dec2 = torch.cat((upconv2, enc2), dim=1) | |
| 88 | + dec2 = self.decoder2(dec2) | |
| 89 | + | |
| 90 | + upconv1 = self.upconv1(dec2) | |
| 91 | + dec1 = torch.cat((upconv1, enc1), dim=1) | |
| 92 | + dec1 = self.decoder1(dec1) | |
| 93 | + | |
| 94 | + conv = self.conv(dec1) | |
| 95 | + | |
| 96 | + sigmoid = torch.sigmoid(conv) | |
| 97 | + | |
| 98 | + return sigmoid | |
| 99 | + | |
| 100 | + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"): | |
| 101 | + return nn.Sequential( | |
| 102 | + OrderedDict( | |
| 103 | + ( | |
| 104 | + ( | |
| 105 | + f"{name}_conv1", | |
| 106 | + nn.Conv3d( | |
| 107 | + in_channels=in_channels, | |
| 108 | + out_channels=features, | |
| 109 | + kernel_size=kernel_size, | |
| 110 | + padding=padding, | |
| 111 | + bias=True, | |
| 112 | + ), | |
| 113 | + ), | |
| 114 | + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)), | |
| 115 | + (f"{name}_relu1", nn.ReLU(inplace=True)), | |
| 116 | + ( | |
| 117 | + f"{name}_conv2", | |
| 118 | + nn.Conv3d( | |
| 119 | + in_channels=features, | |
| 120 | + out_channels=features, | |
| 121 | + kernel_size=kernel_size, | |
| 122 | + padding=padding, | |
| 123 | + bias=True, | |
| 124 | + ), | |
| 125 | + ), | |
| 126 | + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)), | |
| 127 | + (f"{name}_relu2", nn.ReLU(inplace=True)), | |
| 128 | + ) | |
| 129 | + ) | |
| 130 | + ) | |
| 131 | + | |
| 132 | + | |
| 133 | +def main(): | |
| 134 | + import torchviz | |
| 135 | + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
| 136 | + model = Unet3D() | |
| 137 | + model.to(dev) | |
| 138 | + model.eval() | |
| 139 | + print(next(model.parameters()).is_cuda) # True | |
| 140 | + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev) | |
| 141 | + out = model(img) | |
| 142 | + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) | |
| 143 | + dot.render("unet", format="png") | |
| 144 | + torch.save(model, "model.pth") | |
| 145 | + print(dot) | |
| 146 | + | |
| 147 | + | |
| 148 | +if __name__ == "__main__": | |
| 149 | + main() | ... | ... |
invesalius/segmentation/brain/segment.py
| ... | ... | @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc |
| 13 | 13 | from invesalius import inv_paths |
| 14 | 14 | from invesalius.data import imagedata_utils |
| 15 | 15 | from invesalius.utils import new_name_by_pattern |
| 16 | +from invesalius.net.utils import download_url_to_file | |
| 17 | +from invesalius import inv_paths | |
| 16 | 18 | |
| 17 | 19 | from . import utils |
| 18 | 20 | |
| ... | ... | @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): |
| 64 | 66 | 0 : ez - iz, 0 : ey - iy, 0 : ex - ix |
| 65 | 67 | ] |
| 66 | 68 | |
| 69 | +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE): | |
| 70 | + import torch | |
| 71 | + with torch.no_grad(): | |
| 72 | + (iz, ez), (iy, ey), (ix, ex) = patch | |
| 73 | + sub_mask = nn_model( | |
| 74 | + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device) | |
| 75 | + ).cpu().numpy() | |
| 76 | + return sub_mask.reshape(patch_size, patch_size, patch_size)[ | |
| 77 | + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix | |
| 78 | + ] | |
| 79 | + | |
| 67 | 80 | |
| 68 | 81 | def brain_segment(image, probability_array, comm_array): |
| 69 | 82 | import keras |
| ... | ... | @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): |
| 89 | 102 | comm_array[0] = np.Inf |
| 90 | 103 | |
| 91 | 104 | |
| 105 | +def download_callback(comm_array): | |
| 106 | + def _download_callback(value): | |
| 107 | + comm_array[0] = value | |
| 108 | + return _download_callback | |
| 109 | + | |
| 110 | +def brain_segment_torch(image, device_id, probability_array, comm_array): | |
| 111 | + import torch | |
| 112 | + from .model import Unet3D | |
| 113 | + device = torch.device(device_id) | |
| 114 | + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | |
| 115 | + if not state_dict_file.exists(): | |
| 116 | + download_url_to_file( | |
| 117 | + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", | |
| 118 | + state_dict_file, | |
| 119 | + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", | |
| 120 | + download_callback(comm_array) | |
| 121 | + ) | |
| 122 | + state_dict = torch.load(str(state_dict_file)) | |
| 123 | + model = Unet3D() | |
| 124 | + model.load_state_dict(state_dict["model_state_dict"]) | |
| 125 | + model.to(device) | |
| 126 | + model.eval() | |
| 127 | + | |
| 128 | + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32) | |
| 129 | + sums = np.zeros_like(image) | |
| 130 | + # segmenting by patches | |
| 131 | + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP): | |
| 132 | + comm_array[0] = completion | |
| 133 | + (iz, ez), (iy, ey), (ix, ex) = patch | |
| 134 | + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE) | |
| 135 | + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask | |
| 136 | + sums[iz:ez, iy:ey, ix:ex] += 1 | |
| 137 | + | |
| 138 | + probability_array /= sums | |
| 139 | + comm_array[0] = np.Inf | |
| 140 | + | |
| 92 | 141 | ctx = multiprocessing.get_context('spawn') |
| 93 | 142 | class SegmentProcess(ctx.Process): |
| 94 | 143 | def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): |
| ... | ... | @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): |
| 138 | 187 | mode="r", |
| 139 | 188 | ) |
| 140 | 189 | |
| 141 | - print(image.min(), image.max()) | |
| 142 | - if self.apply_segment_threshold: | |
| 190 | + if self.apply_wwwl: | |
| 143 | 191 | print("Applying window level") |
| 144 | 192 | image = get_LUT_value(image, self.window_width, self.window_level) |
| 145 | 193 | |
| ... | ... | @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): |
| 153 | 201 | self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" |
| 154 | 202 | ) |
| 155 | 203 | |
| 156 | - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | |
| 157 | - brain_segment(image, probability_array, comm_array) | |
| 204 | + if self.backend.lower() == "pytorch": | |
| 205 | + brain_segment_torch(image, self.device_id, probability_array, comm_array) | |
| 206 | + else: | |
| 207 | + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) | |
| 208 | + brain_segment(image, probability_array, comm_array) | |
| 158 | 209 | |
| 159 | 210 | @property |
| 160 | 211 | def exception(self): | ... | ... |
requirements.txt