From 91c68b26fbc93e4e77e1bf4df50b612fceb93c47 Mon Sep 17 00:00:00 2001 From: Thiago Franco de Moraes Date: Tue, 1 Feb 2022 14:32:29 -0300 Subject: [PATCH] Try to open pytorch mri t1 brain weights in ai/brain_mri_t1/ --- invesalius/segmentation/brain/segment.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/invesalius/segmentation/brain/segment.py b/invesalius/segmentation/brain/segment.py index 5280ec6..f6e7b13 100644 --- a/invesalius/segmentation/brain/segment.py +++ b/invesalius/segmentation/brain/segment.py @@ -100,15 +100,22 @@ def brain_segment_torch(image, device_id, probability_array, comm_array): import torch from .model import Unet3D device = torch.device(device_id) - state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") - if not state_dict_file.exists(): + folder = inv_paths.MODELS_DIR.joinpath("brain_mri_t1") + system_state_dict_file = folder.joinpath("brain_mri_t1.pt") + user_state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") + if not system_state_dict_file.exists() and not user_state_dict_file.exists(): download_url_to_file( "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", - state_dict_file, + user_state_dict_file, "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", download_callback(comm_array) ) - state_dict = torch.load(str(state_dict_file)) + if user_state_dict_file.exists(): + state_dict = torch.load(str(user_state_dict_file)) + elif system_state_dict_file.exists(): + state_dict = torch.load(str(system_state_dict_file)) + else: + raise FileNotFoundError("Weights file not found") model = Unet3D() model.load_state_dict(state_dict["model_state_dict"]) model.to(device) -- libgit2 0.21.2