Commit a406f61e7cce53232cda3932c8981e993deb61fb

Authored by Thiago Franco de Moraes
Committed by GitHub
1 parent 9460b9cd
Exists in master

Add pytorch backend to brain segmentation (#365)

* added pytorch to requirements.txt and showing in GUI

* segmenting using pytorch

* Remove debug prints

* Added an function to download a file and save it localy

* Downloading the weight using the developed function to download file
invesalius/gui/brain_seg_dialog.py
... ... @@ -22,6 +22,22 @@ HAS_THEANO = bool(importlib.util.find_spec("theano"))
22 22 HAS_PLAIDML = bool(importlib.util.find_spec("plaidml"))
23 23 PLAIDML_DEVICES = {}
24 24  
  25 +try:
  26 + import torch
  27 + HAS_TORCH = True
  28 +except ImportError:
  29 + HAS_TORCH = False
  30 +
  31 +if HAS_TORCH:
  32 + TORCH_DEVICES = {}
  33 + if torch.cuda.is_available():
  34 + for i in range(torch.cuda.device_count()):
  35 + name = torch.cuda.get_device_name()
  36 + device_id = f'cuda:{i}'
  37 + TORCH_DEVICES[name] = device_id
  38 + TORCH_DEVICES['CPU'] = 'cpu'
  39 +
  40 +
25 41  
26 42 if HAS_PLAIDML:
27 43 with multiprocessing.Pool(1) as p:
... ... @@ -43,12 +59,15 @@ class BrainSegmenterDialog(wx.Dialog):
43 59 style=wx.DEFAULT_DIALOG_STYLE | wx.FRAME_FLOAT_ON_PARENT,
44 60 )
45 61 backends = []
  62 + if HAS_TORCH:
  63 + backends.append("Pytorch")
46 64 if HAS_PLAIDML:
47 65 backends.append("PlaidML")
48 66 if HAS_THEANO:
49 67 backends.append("Theano")
50 68 # self.segmenter = segment.BrainSegmenter()
51 69 # self.pg_dialog = None
  70 + self.torch_devices = TORCH_DEVICES
52 71 self.plaidml_devices = PLAIDML_DEVICES
53 72  
54 73 self.ps = None
... ... @@ -65,13 +84,19 @@ class BrainSegmenterDialog(wx.Dialog):
65 84 w, h = self.CalcSizeFromTextSize("MM" * (1 + max(len(i) for i in backends)))
66 85 self.cb_backends.SetMinClientSize((w, -1))
67 86 self.chk_use_gpu = wx.CheckBox(self, wx.ID_ANY, _("Use GPU"))
68   - if HAS_PLAIDML:
  87 + if HAS_TORCH or HAS_PLAIDML:
  88 + if HAS_TORCH:
  89 + choices = list(self.torch_devices.keys())
  90 + value = choices[0]
  91 + else:
  92 + choices = list(self.plaidml_devices.keys())
  93 + value = choices[0]
69 94 self.lbl_device = wx.StaticText(self, -1, _("Device"))
70 95 self.cb_devices = wx.ComboBox(
71 96 self,
72 97 wx.ID_ANY,
73   - choices=list(self.plaidml_devices.keys()),
74   - value=list(self.plaidml_devices.keys())[0],
  98 + choices=choices,
  99 + value=value,
75 100 style=wx.CB_DROPDOWN | wx.CB_READONLY,
76 101 )
77 102 self.sld_threshold = wx.Slider(self, wx.ID_ANY, 75, 0, 100)
... ... @@ -109,7 +134,7 @@ class BrainSegmenterDialog(wx.Dialog):
109 134 main_sizer.Add(sizer_backends, 0, wx.ALL | wx.EXPAND, 5)
110 135 main_sizer.Add(self.chk_use_gpu, 0, wx.ALL, 5)
111 136 sizer_devices = wx.BoxSizer(wx.HORIZONTAL)
112   - if HAS_PLAIDML:
  137 + if HAS_TORCH or HAS_PLAIDML:
113 138 sizer_devices.Add(self.lbl_device, 0, wx.ALIGN_CENTER, 0)
114 139 sizer_devices.Add(self.cb_devices, 1, wx.LEFT, 5)
115 140 main_sizer.Add(sizer_devices, 0, wx.ALL | wx.EXPAND, 5)
... ... @@ -177,8 +202,21 @@ class BrainSegmenterDialog(wx.Dialog):
177 202 return width, height
178 203  
179 204 def OnSetBackend(self, evt=None):
180   - if self.cb_backends.GetValue().lower() == "plaidml":
  205 + if self.cb_backends.GetValue().lower() == "pytorch":
  206 + if HAS_TORCH:
  207 + choices = list(self.torch_devices.keys())
  208 + self.cb_devices.Clear()
  209 + self.cb_devices.SetItems(choices)
  210 + self.cb_devices.SetValue(choices[0])
  211 + self.lbl_device.Show()
  212 + self.cb_devices.Show()
  213 + self.chk_use_gpu.Hide()
  214 + elif self.cb_backends.GetValue().lower() == "plaidml":
181 215 if HAS_PLAIDML:
  216 + choices = list(self.plaidml_devices.keys())
  217 + self.cb_devices.Clear()
  218 + self.cb_devices.SetItems(choices)
  219 + self.cb_devices.SetValue(choices[0])
182 220 self.lbl_device.Show()
183 221 self.cb_devices.Show()
184 222 self.chk_use_gpu.Hide()
... ... @@ -216,10 +254,16 @@ class BrainSegmenterDialog(wx.Dialog):
216 254 self.elapsed_time_timer.Start(1000)
217 255 image = slc.Slice().matrix
218 256 backend = self.cb_backends.GetValue()
219   - try:
220   - device_id = self.plaidml_devices[self.cb_devices.GetValue()]
221   - except (KeyError, AttributeError):
222   - device_id = "llvm_cpu.0"
  257 + if backend.lower() == "pytorch":
  258 + try:
  259 + device_id = self.torch_devices[self.cb_devices.GetValue()]
  260 + except (KeyError, AttributeError):
  261 + device_id = "cpu"
  262 + else:
  263 + try:
  264 + device_id = self.plaidml_devices[self.cb_devices.GetValue()]
  265 + except (KeyError, AttributeError):
  266 + device_id = "llvm_cpu.0"
223 267 apply_wwwl = self.chk_apply_wwwl.GetValue()
224 268 create_new_mask = self.chk_new_mask.GetValue()
225 269 use_gpu = self.chk_use_gpu.GetValue()
... ...
invesalius/inv_paths.py
... ... @@ -27,6 +27,7 @@ CONF_DIR = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", USER_DIR.joinpath(".co
27 27 USER_INV_DIR = CONF_DIR.joinpath("invesalius")
28 28 USER_PRESET_DIR = USER_INV_DIR.joinpath("presets")
29 29 USER_LOG_DIR = USER_INV_DIR.joinpath("logs")
  30 +USER_DL_WEIGHTS = USER_INV_DIR.joinpath("deep_learning/weights/")
30 31 USER_RAYCASTING_PRESETS_DIRECTORY = USER_PRESET_DIR.joinpath("raycasting")
31 32 TEMP_DIR = tempfile.gettempdir()
32 33  
... ... @@ -97,6 +98,7 @@ def create_conf_folders():
97 98 USER_INV_DIR.mkdir(parents=True, exist_ok=True)
98 99 USER_PRESET_DIR.mkdir(parents=True, exist_ok=True)
99 100 USER_LOG_DIR.mkdir(parents=True, exist_ok=True)
  101 + USER_DL_WEIGHTS.mkdir(parents=True, exist_ok=True)
100 102 USER_PLUGINS_DIRECTORY.mkdir(parents=True, exist_ok=True)
101 103  
102 104  
... ...
invesalius/net/utils.py 0 → 100644
... ... @@ -0,0 +1,48 @@
  1 +from urllib.error import HTTPError
  2 +from urllib.request import urlopen, Request
  3 +from urllib.parse import urlparse
  4 +import pathlib
  5 +import tempfile
  6 +import typing
  7 +import hashlib
  8 +import os
  9 +import shutil
  10 +
  11 +def download_url_to_file(url: str, dst: pathlib.Path, hash: str = None, callback: typing.Callable[[float], None] = None):
  12 + file_size = None
  13 + total_downloaded = 0
  14 + if hash is not None:
  15 + calc_hash = hashlib.sha256()
  16 + req = Request(url)
  17 + response = urlopen(req)
  18 + meta = response.info()
  19 + if hasattr(meta, "getheaders"):
  20 + content_length = meta.getheaders("Content-Length")
  21 + else:
  22 + content_length = meta.get_all("Content-Length")
  23 +
  24 + if content_length is not None and len(content_length) > 0:
  25 + file_size = int(content_length[0])
  26 + dst.parent.mkdir(parents=True, exist_ok=True)
  27 + f = tempfile.NamedTemporaryFile(delete=False, dir=dst.parent)
  28 + try:
  29 + while True:
  30 + buffer = response.read(8192)
  31 + if len(buffer) == 0:
  32 + break
  33 + total_downloaded += len(buffer)
  34 + f.write(buffer)
  35 + if hash:
  36 + calc_hash.update(buffer)
  37 + if callback is not None:
  38 + callback(100 * total_downloaded/file_size)
  39 + f.close()
  40 + if hash is not None:
  41 + digest = calc_hash.hexdigest()
  42 + if digest != hash:
  43 + raise RuntimeError(f'Invalid hash value (expected "{hash}", got "{digest}")')
  44 + shutil.move(f.name, dst)
  45 + finally:
  46 + f.close()
  47 + if os.path.exists(f.name):
  48 + os.remove(f.name)
... ...
invesalius/segmentation/brain/model.py 0 → 100644
... ... @@ -0,0 +1,149 @@
  1 +from collections import OrderedDict
  2 +
  3 +import torch
  4 +import torch.nn as nn
  5 +
  6 +SIZE = 48
  7 +
  8 +class Unet3D(nn.Module):
  9 + # Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py
  10 + def __init__(self, in_channels=1, out_channels=1, init_features=8):
  11 + super().__init__()
  12 + features = init_features
  13 +
  14 + self.encoder1 = self._block(
  15 + in_channels, features=features, padding=2, name="enc1"
  16 + )
  17 + self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
  18 +
  19 + self.encoder2 = self._block(
  20 + features, features=features * 2, padding=2, name="enc2"
  21 + )
  22 + self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
  23 +
  24 + self.encoder3 = self._block(
  25 + features * 2, features=features * 4, padding=2, name="enc3"
  26 + )
  27 + self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
  28 +
  29 + self.encoder4 = self._block(
  30 + features * 4, features=features * 8, padding=2, name="enc4"
  31 + )
  32 + self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
  33 +
  34 + self.bottleneck = self._block(
  35 + features * 8, features=features * 16, padding=2, name="bottleneck"
  36 + )
  37 +
  38 + self.upconv4 = nn.ConvTranspose3d(
  39 + features * 16, features * 8, kernel_size=4, stride=2, padding=1
  40 + )
  41 + self.decoder4 = self._block(
  42 + features * 16, features=features * 8, padding=2, name="dec4"
  43 + )
  44 +
  45 + self.upconv3 = nn.ConvTranspose3d(
  46 + features * 8, features * 4, kernel_size=4, stride=2, padding=1
  47 + )
  48 + self.decoder3 = self._block(
  49 + features * 8, features=features * 4, padding=2, name="dec4"
  50 + )
  51 +
  52 + self.upconv2 = nn.ConvTranspose3d(
  53 + features * 4, features * 2, kernel_size=4, stride=2, padding=1
  54 + )
  55 + self.decoder2 = self._block(
  56 + features * 4, features=features * 2, padding=2, name="dec4"
  57 + )
  58 +
  59 + self.upconv1 = nn.ConvTranspose3d(
  60 + features * 2, features, kernel_size=4, stride=2, padding=1
  61 + )
  62 + self.decoder1 = self._block(
  63 + features * 2, features=features, padding=2, name="dec4"
  64 + )
  65 +
  66 + self.conv = nn.Conv3d(
  67 + in_channels=features, out_channels=out_channels, kernel_size=1
  68 + )
  69 +
  70 + def forward(self, img):
  71 + enc1 = self.encoder1(img)
  72 + enc2 = self.encoder2(self.pool1(enc1))
  73 + enc3 = self.encoder3(self.pool2(enc2))
  74 + enc4 = self.encoder4(self.pool3(enc3))
  75 +
  76 + bottleneck = self.bottleneck(self.pool4(enc4))
  77 +
  78 + upconv4 = self.upconv4(bottleneck)
  79 + dec4 = torch.cat((upconv4, enc4), dim=1)
  80 + dec4 = self.decoder4(dec4)
  81 +
  82 + upconv3 = self.upconv3(dec4)
  83 + dec3 = torch.cat((upconv3, enc3), dim=1)
  84 + dec3 = self.decoder3(dec3)
  85 +
  86 + upconv2 = self.upconv2(dec3)
  87 + dec2 = torch.cat((upconv2, enc2), dim=1)
  88 + dec2 = self.decoder2(dec2)
  89 +
  90 + upconv1 = self.upconv1(dec2)
  91 + dec1 = torch.cat((upconv1, enc1), dim=1)
  92 + dec1 = self.decoder1(dec1)
  93 +
  94 + conv = self.conv(dec1)
  95 +
  96 + sigmoid = torch.sigmoid(conv)
  97 +
  98 + return sigmoid
  99 +
  100 + def _block(self, in_channels, features, padding=1, kernel_size=5, name="block"):
  101 + return nn.Sequential(
  102 + OrderedDict(
  103 + (
  104 + (
  105 + f"{name}_conv1",
  106 + nn.Conv3d(
  107 + in_channels=in_channels,
  108 + out_channels=features,
  109 + kernel_size=kernel_size,
  110 + padding=padding,
  111 + bias=True,
  112 + ),
  113 + ),
  114 + (f"{name}_norm1", nn.BatchNorm3d(num_features=features)),
  115 + (f"{name}_relu1", nn.ReLU(inplace=True)),
  116 + (
  117 + f"{name}_conv2",
  118 + nn.Conv3d(
  119 + in_channels=features,
  120 + out_channels=features,
  121 + kernel_size=kernel_size,
  122 + padding=padding,
  123 + bias=True,
  124 + ),
  125 + ),
  126 + (f"{name}_norm2", nn.BatchNorm3d(num_features=features)),
  127 + (f"{name}_relu2", nn.ReLU(inplace=True)),
  128 + )
  129 + )
  130 + )
  131 +
  132 +
  133 +def main():
  134 + import torchviz
  135 + dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  136 + model = Unet3D()
  137 + model.to(dev)
  138 + model.eval()
  139 + print(next(model.parameters()).is_cuda) # True
  140 + img = torch.randn(1, SIZE, SIZE, SIZE, 1).to(dev)
  141 + out = model(img)
  142 + dot = torchviz.make_dot(out, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
  143 + dot.render("unet", format="png")
  144 + torch.save(model, "model.pth")
  145 + print(dot)
  146 +
  147 +
  148 +if __name__ == "__main__":
  149 + main()
... ...
invesalius/segmentation/brain/segment.py
... ... @@ -13,6 +13,8 @@ import invesalius.data.slice_ as slc
13 13 from invesalius import inv_paths
14 14 from invesalius.data import imagedata_utils
15 15 from invesalius.utils import new_name_by_pattern
  16 +from invesalius.net.utils import download_url_to_file
  17 +from invesalius import inv_paths
16 18  
17 19 from . import utils
18 20  
... ... @@ -64,6 +66,17 @@ def predict_patch(sub_image, patch, nn_model, patch_size=SIZE):
64 66 0 : ez - iz, 0 : ey - iy, 0 : ex - ix
65 67 ]
66 68  
  69 +def predict_patch_torch(sub_image, patch, nn_model, device, patch_size=SIZE):
  70 + import torch
  71 + with torch.no_grad():
  72 + (iz, ez), (iy, ey), (ix, ex) = patch
  73 + sub_mask = nn_model(
  74 + torch.from_numpy(sub_image.reshape(1, 1, patch_size, patch_size, patch_size)).to(device)
  75 + ).cpu().numpy()
  76 + return sub_mask.reshape(patch_size, patch_size, patch_size)[
  77 + 0 : ez - iz, 0 : ey - iy, 0 : ex - ix
  78 + ]
  79 +
67 80  
68 81 def brain_segment(image, probability_array, comm_array):
69 82 import keras
... ... @@ -89,6 +102,42 @@ def brain_segment(image, probability_array, comm_array):
89 102 comm_array[0] = np.Inf
90 103  
91 104  
  105 +def download_callback(comm_array):
  106 + def _download_callback(value):
  107 + comm_array[0] = value
  108 + return _download_callback
  109 +
  110 +def brain_segment_torch(image, device_id, probability_array, comm_array):
  111 + import torch
  112 + from .model import Unet3D
  113 + device = torch.device(device_id)
  114 + state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt")
  115 + if not state_dict_file.exists():
  116 + download_url_to_file(
  117 + "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt",
  118 + state_dict_file,
  119 + "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22",
  120 + download_callback(comm_array)
  121 + )
  122 + state_dict = torch.load(str(state_dict_file))
  123 + model = Unet3D()
  124 + model.load_state_dict(state_dict["model_state_dict"])
  125 + model.to(device)
  126 + model.eval()
  127 +
  128 + image = imagedata_utils.image_normalize(image, 0.0, 1.0, output_dtype=np.float32)
  129 + sums = np.zeros_like(image)
  130 + # segmenting by patches
  131 + for completion, sub_image, patch in gen_patches(image, SIZE, OVERLAP):
  132 + comm_array[0] = completion
  133 + (iz, ez), (iy, ey), (ix, ex) = patch
  134 + sub_mask = predict_patch_torch(sub_image, patch, model, device, SIZE)
  135 + probability_array[iz:ez, iy:ey, ix:ex] += sub_mask
  136 + sums[iz:ez, iy:ey, ix:ex] += 1
  137 +
  138 + probability_array /= sums
  139 + comm_array[0] = np.Inf
  140 +
92 141 ctx = multiprocessing.get_context('spawn')
93 142 class SegmentProcess(ctx.Process):
94 143 def __init__(self, image, create_new_mask, backend, device_id, use_gpu, apply_wwwl=False, window_width=255, window_level=127):
... ... @@ -138,8 +187,7 @@ class SegmentProcess(ctx.Process):
138 187 mode="r",
139 188 )
140 189  
141   - print(image.min(), image.max())
142   - if self.apply_segment_threshold:
  190 + if self.apply_wwwl:
143 191 print("Applying window level")
144 192 image = get_LUT_value(image, self.window_width, self.window_level)
145 193  
... ... @@ -153,8 +201,11 @@ class SegmentProcess(ctx.Process):
153 201 self._comm_array_filename, dtype=np.float32, shape=(1,), mode="r+"
154 202 )
155 203  
156   - utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
157   - brain_segment(image, probability_array, comm_array)
  204 + if self.backend.lower() == "pytorch":
  205 + brain_segment_torch(image, self.device_id, probability_array, comm_array)
  206 + else:
  207 + utils.prepare_ambient(self.backend, self.device_id, self.use_gpu)
  208 + brain_segment(image, probability_array, comm_array)
158 209  
159 210 @property
160 211 def exception(self):
... ...
requirements.txt
... ... @@ -15,3 +15,4 @@ scipy==1.7.1
15 15 vtk==9.0.3
16 16 wxPython==4.1.1
17 17 Theano==1.0.5
  18 +torch==1.9.1
... ...