diff --git a/invesalius/segmentation/brain/segment.py b/invesalius/segmentation/brain/segment.py index 5280ec6..f6e7b13 100644 --- a/invesalius/segmentation/brain/segment.py +++ b/invesalius/segmentation/brain/segment.py @@ -100,15 +100,22 @@ def brain_segment_torch(image, device_id, probability_array, comm_array): import torch from .model import Unet3D device = torch.device(device_id) - state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") - if not state_dict_file.exists(): + folder = inv_paths.MODELS_DIR.joinpath("brain_mri_t1") + system_state_dict_file = folder.joinpath("brain_mri_t1.pt") + user_state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") + if not system_state_dict_file.exists() and not user_state_dict_file.exists(): download_url_to_file( "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", - state_dict_file, + user_state_dict_file, "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", download_callback(comm_array) ) - state_dict = torch.load(str(state_dict_file)) + if user_state_dict_file.exists(): + state_dict = torch.load(str(user_state_dict_file)) + elif system_state_dict_file.exists(): + state_dict = torch.load(str(system_state_dict_file)) + else: + raise FileNotFoundError("Weights file not found") model = Unet3D() model.load_state_dict(state_dict["model_state_dict"]) model.to(device) -- libgit2 0.21.2