Commit 91c68b26fbc93e4e77e1bf4df50b612fceb93c47
1 parent
13163add
Exists in
master
Try to open pytorch mri t1 brain weights in ai/brain_mri_t1/
Showing
1 changed file
with
11 additions
and
4 deletions
Show diff stats
invesalius/segmentation/brain/segment.py
... | ... | @@ -100,15 +100,22 @@ def brain_segment_torch(image, device_id, probability_array, comm_array): |
100 | 100 | import torch |
101 | 101 | from .model import Unet3D |
102 | 102 | device = torch.device(device_id) |
103 | - state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | |
104 | - if not state_dict_file.exists(): | |
103 | + folder = inv_paths.MODELS_DIR.joinpath("brain_mri_t1") | |
104 | + system_state_dict_file = folder.joinpath("brain_mri_t1.pt") | |
105 | + user_state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | |
106 | + if not system_state_dict_file.exists() and not user_state_dict_file.exists(): | |
105 | 107 | download_url_to_file( |
106 | 108 | "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", |
107 | - state_dict_file, | |
109 | + user_state_dict_file, | |
108 | 110 | "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", |
109 | 111 | download_callback(comm_array) |
110 | 112 | ) |
111 | - state_dict = torch.load(str(state_dict_file)) | |
113 | + if user_state_dict_file.exists(): | |
114 | + state_dict = torch.load(str(user_state_dict_file)) | |
115 | + elif system_state_dict_file.exists(): | |
116 | + state_dict = torch.load(str(system_state_dict_file)) | |
117 | + else: | |
118 | + raise FileNotFoundError("Weights file not found") | |
112 | 119 | model = Unet3D() |
113 | 120 | model.load_state_dict(state_dict["model_state_dict"]) |
114 | 121 | model.to(device) | ... | ... |