From a406f61e7cce53232cda3932c8981e993deb61fb Mon Sep 17 00:00:00 2001 From: Thiago Franco de Moraes Date: Fri, 8 Oct 2021 16:16:53 -0300 Subject: [PATCH] Add pytorch backend to brain segmentation (#365) --- invesalius/gui/brain_seg_dialog.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++--------- invesalius/inv_paths.py | 2 ++ invesalius/net/utils.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ invesalius/segmentation/brain/model.py | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ invesalius/segmentation/brain/segment.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++---- requirements.txt | 1 + 6 files changed, 308 insertions(+), 13 deletions(-) create mode 100644 invesalius/net/utils.py create mode 100644 invesalius/segmentation/brain/model.py diff --git a/invesalius/gui/brain_seg_dialog.py b/invesalius/gui/brain_seg_dialog.py index 3f2f91b..577c92f 100644 --- a/invesalius/gui/brain_seg_dialog.py +++ b/invesalius/gui/brain_seg_dialog.py @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano")) HAS_PLAIDML = bool(importlib.util.find_spec("plaidml")) PLAIDML_DEVICES = {} +try: + import torch + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + +if HAS_TORCH: + TORCH_DEVICES = {} + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name() + device_id = f'cuda:{i}' + TORCH_DEVICES[name] = device_id + TORCH_DEVICES['CPU'] = 'cpu' + + if HAS_PLAIDML: with multiprocessing.Pool(1) as p: @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog): style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT, ) backends = [] + if HAS_TORCH: + backends.append("Pytorch") if HAS_PLAIDML: backends.append("PlaidML") if HAS_THEANO: backends.append("Theano") # self.segmenter = segment.BrainSegmenter() # self.pg_dialog = None + self.torch_devices = TORCH_DEVICES self.plaidml_devices = PLAIDML_DEVICES self.ps = None @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog): w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends))) self.cb_backends.SetMinClientSize((w, -1)) self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU")) - if HAS_PLAIDML: + if HAS_TORCH or HAS_PLAIDML: + if HAS_TORCH: + choices = list(self.torch_devices.keys()) + value = choices[0] + else: + choices = list(self.plaidml_devices.keys()) + value = choices[0] self.lbl_device = wx.StaticText(self, -1, _("Device")) self.cb_devices = wx.ComboBox( self, wx.ID_ANY, - choices=list(self.plaidml_devices.keys()), - value=list(self.plaidml_devices.keys())[0], + choices=choices, + value=value, style=wx.CB_DROPDOWN | wx.CB_READONLY, ) self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100) @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog): main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5) main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5) sizer_devices = wx.BoxSizer(wx.HORIZONTAL) - if HAS_PLAIDML: + if HAS_TORCH or HAS_PLAIDML: sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0) sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5) main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5) @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog): return width, height def OnSetBackend(self, evt=None): - if self.cb_backends.GetValue().lower() == "plaidml": + if self.cb_backends.GetValue().lower() == "pytorch": + if HAS_TORCH: + choices = list(self.torch_devices.keys()) + self.cb_devices.Clear() + self.cb_devices.SetItems(choices) + self.cb_devices.SetValue(choices[0]) + self.lbl_device.Show() + self.cb_devices.Show() + self.chk_use_gpu.Hide() + elif self.cb_backends.GetValue().lower() == "plaidml": if HAS_PLAIDML: + choices = list(self.plaidml_devices.keys()) + self.cb_devices.Clear() + self.cb_devices.SetItems(choices) + self.cb_devices.SetValue(choices[0]) self.lbl_device.Show() self.cb_devices.Show() self.chk_use_gpu.Hide() @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog): self.elapsed_time_timer.Start(1000) image = slc.Slice().matrix backend = self.cb_backends.GetValue() - try: - device_id = self.plaidml_devices[self.cb_devices.GetValue()] - except (KeyError, AttributeError): - device_id = "llvm_cpu.0" + if backend.lower() == "pytorch": + try: + device_id = self.torch_devices[self.cb_devices.GetValue()] + except (KeyError, AttributeError): + device_id = "cpu" + else: + try: + device_id = self.plaidml_devices[self.cb_devices.GetValue()] + except (KeyError, AttributeError): + device_id = "llvm_cpu.0" apply_wwwl = self.chk_apply_wwwl.GetValue() create_new_mask = self.chk_new_mask.GetValue() use_gpu = self.chk_use_gpu.GetValue() diff --git a/invesalius/inv_paths.py b/invesalius/inv_paths.py index d75b7c0..85a43a7 100644 --- a/invesalius/inv_paths.py +++ b/invesalius/inv_paths.py @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co USER_INV_DIR = CONF_DIR.joinpath("invesalius") USER_PRESET_DIR = USER_INV_DIR.joinpath("presets") USER_LOG_DIR = USER_INV_DIR.joinpath("logs") +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/") USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting") TEMP_DIR = tempfile.gettempdir() @@ -97,6 +98,7 @@ def create_conf_folders(): USER_INV_DIR.mkdir(parents=True, exist_ok=True) USER_PRESET_DIR.mkdir(parents=True, exist_ok=True) USER_LOG_DIR.mkdir(parents=True, exist_ok=True) + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True) USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True) diff --git a/invesalius/net/utils.py b/invesalius/net/utils.py new file mode 100644 index 0000000..a02a9be --- /dev/null +++ b/invesalius/net/utils.py @@ -0,0 +1,48 @@ +from urllib.error import HTTPError +from urllib.request import urlopen, Request +from urllib.parse import urlparse +import pathlib +import tempfile +import typing +import hashlib +import os +import shutil + +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None): + file_size = None + total_downloaded = 0 + if hash is not None: + calc_hash = hashlib.sha256() + req = Request(url) + response = urlopen(req) + meta = response.info() + if hasattr(meta, "getheaders"): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + dst.parent.mkdir(parents=True, exist_ok=True) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent) + try: + while True: + buffer = response.read(8192) + if len(buffer) == 0: + break + total_downloaded += len(buffer) + f.write(buffer) + if hash: + calc_hash.update(buffer) + if callback is not None: + callback(100 * total_downloaded/file_size) + f.close() + if hash is not None: + digest = calc_hash.hexdigest() + if digest != hash: + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")') + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) diff --git a/invesalius/segmentation/brain/model.py b/invesalius/segmentation/brain/model.py new file mode 100644 index 0000000..29e3902 --- /dev/null +++ b/invesalius/segmentation/brain/model.py @@ -0,0 +1,149 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + +SIZE = 48 + +class Unet3D(nn.Module): + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py + def __init__(self, in_channels=1, out_channels=1, init_features=8): + super().__init__() + features = init_features + + self.encoder1 = self._block( + in_channels, features=features, padding=2, name="enc1" + ) + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) + + self.encoder2 = self._block( + features, features=features * 2, padding=2, name="enc2" + ) + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) + + self.encoder3 = self._block( + features * 2, features=features * 4, padding=2, name="enc3" + ) + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) + + self.encoder4 = self._block( + features * 4, features=features * 8, padding=2, name="enc4" + ) + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) + + self.bottleneck = self._block( + features * 8, features=features * 16, padding=2, name="bottleneck" + ) + + self.upconv4 = nn.ConvTranspose3d( + features * 16, features * 8, kernel_size=4, stride=2, padding=1 + ) + self.decoder4 = self._block( + features * 16, features=features * 8, padding=2, name="dec4" + ) + + self.upconv3 = nn.ConvTranspose3d( + features * 8, features * 4, kernel_size=4, stride=2, padding=1 + ) + self.decoder3 = self._block( + features * 8, features=features * 4, padding=2, name="dec4" + ) + + self.upconv2 = nn.ConvTranspose3d( + features * 4, features * 2, kernel_size=4, stride=2, padding=1 + ) + self.decoder2 = self._block( + features * 4, features=features * 2, padding=2, name="dec4" + ) + + self.upconv1 = nn.ConvTranspose3d( + features * 2, features, kernel_size=4, stride=2, padding=1 + ) + self.decoder1 = self._block( + features * 2, features=features, padding=2, name="dec4" + ) + + self.conv = nn.Conv3d( + in_channels=features, out_channels=out_channels, kernel_size=1 + ) + + def forward(self, img): + enc1 = self.encoder1(img) + enc2 = self.encoder2(self.pool1(enc1)) + enc3 = self.encoder3(self.pool2(enc2)) + enc4 = self.encoder4(self.pool3(enc3)) + + bottleneck = self.bottleneck(self.pool4(enc4)) + + upconv4 = self.upconv4(bottleneck) + dec4 = torch.cat((upconv4, enc4), dim=1) + dec4 = self.decoder4(dec4) + + upconv3 = self.upconv3(dec4) + dec3 = torch.cat((upconv3, enc3), dim=1) + dec3 = self.decoder3(dec3) + + upconv2 = self.upconv2(dec3) + dec2 = torch.cat((upconv2, enc2), dim=1) + dec2 = self.decoder2(dec2) + + upconv1 = self.upconv1(dec2) + dec1 = torch.cat((upconv1, enc1), dim=1) + dec1 = self.decoder1(dec1) + + conv = self.conv(dec1) + + sigmoid = torch.sigmoid(conv) + + return sigmoid + + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"): + return nn.Sequential( + OrderedDict( + ( + ( + f"{name}_conv1", + nn.Conv3d( + in_channels=in_channels, + out_channels=features, + kernel_size=kernel_size, + padding=padding, + bias=True, + ), + ), + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)), + (f"{name}_relu1", nn.ReLU(inplace=True)), + ( + f"{name}_conv2", + nn.Conv3d( + in_channels=features, + out_channels=features, + kernel_size=kernel_size, + padding=padding, + bias=True, + ), + ), + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)), + (f"{name}_relu2", nn.ReLU(inplace=True)), + ) + ) + ) + + +def main(): + import torchviz + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = Unet3D() + model.to(dev) + model.eval() + print(next(model.parameters()).is_cuda) # True + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev) + out = model(img) + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) + dot.render("unet", format="png") + torch.save(model, "model.pth") + print(dot) + + +if __name__ == "__main__": + main() diff --git a/invesalius/segmentation/brain/segment.py b/invesalius/segmentation/brain/segment.py index 3dd2fd8..f1b398d 100644 --- a/invesalius/segmentation/brain/segment.py +++ b/invesalius/segmentation/brain/segment.py @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc from invesalius import inv_paths from invesalius.data import imagedata_utils from invesalius.utils import new_name_by_pattern +from invesalius.net.utils import download_url_to_file +from invesalius import inv_paths from . import utils @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE): 0 : ez - iz, 0 : ey - iy, 0 : ex - ix ] +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE): + import torch + with torch.no_grad(): + (iz, ez), (iy, ey), (ix, ex) = patch + sub_mask = nn_model( + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device) + ).cpu().numpy() + return sub_mask.reshape(patch_size, patch_size, patch_size)[ + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix + ] + def brain_segment(image, probability_array, comm_array): import keras @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array): comm_array[0] = np.Inf +def download_callback(comm_array): + def _download_callback(value): + comm_array[0] = value + return _download_callback + +def brain_segment_torch(image, device_id, probability_array, comm_array): + import torch + from .model import Unet3D + device = torch.device(device_id) + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") + if not state_dict_file.exists(): + download_url_to_file( + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", + state_dict_file, + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", + download_callback(comm_array) + ) + state_dict = torch.load(str(state_dict_file)) + model = Unet3D() + model.load_state_dict(state_dict["model_state_dict"]) + model.to(device) + model.eval() + + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32) + sums = np.zeros_like(image) + # segmenting by patches + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP): + comm_array[0] = completion + (iz, ez), (iy, ey), (ix, ex) = patch + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE) + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask + sums[iz:ez, iy:ey, ix:ex] += 1 + + probability_array /= sums + comm_array[0] = np.Inf + ctx = multiprocessing.get_context('spawn') class SegmentProcess(ctx.Process): def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127): @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process): mode="r", ) - print(image.min(), image.max()) - if self.apply_segment_threshold: + if self.apply_wwwl: print("Applying window level") image = get_LUT_value(image, self.window_width, self.window_level) @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process): self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+" ) - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) - brain_segment(image, probability_array, comm_array) + if self.backend.lower() == "pytorch": + brain_segment_torch(image, self.device_id, probability_array, comm_array) + else: + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu) + brain_segment(image, probability_array, comm_array) @property def exception(self): diff --git a/requirements.txt b/requirements.txt index 407f6cf..569601f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ scipy==1.7.1 vtk==9.0.3 wxPython==4.1.1 Theano==1.0.5 +torch==1.9.1 -- libgit2 0.21.2