Commit 91c68b26fbc93e4e77e1bf4df50b612fceb93c47
1 parent
13163add
Exists in
master
Try to open pytorch mri t1 brain weights in ai/brain_mri_t1/
Showing
1 changed file
with
11 additions
and
4 deletions
Show diff stats
invesalius/segmentation/brain/segment.py
@@ -100,15 +100,22 @@ def brain_segment_torch(image, device_id, probability_array, comm_array): | @@ -100,15 +100,22 @@ def brain_segment_torch(image, device_id, probability_array, comm_array): | ||
100 | import torch | 100 | import torch |
101 | from .model import Unet3D | 101 | from .model import Unet3D |
102 | device = torch.device(device_id) | 102 | device = torch.device(device_id) |
103 | - state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | ||
104 | - if not state_dict_file.exists(): | 103 | + folder = inv_paths.MODELS_DIR.joinpath("brain_mri_t1") |
104 | + system_state_dict_file = folder.joinpath("brain_mri_t1.pt") | ||
105 | + user_state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt") | ||
106 | + if not system_state_dict_file.exists() and not user_state_dict_file.exists(): | ||
105 | download_url_to_file( | 107 | download_url_to_file( |
106 | "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", | 108 | "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt", |
107 | - state_dict_file, | 109 | + user_state_dict_file, |
108 | "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", | 110 | "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22", |
109 | download_callback(comm_array) | 111 | download_callback(comm_array) |
110 | ) | 112 | ) |
111 | - state_dict = torch.load(str(state_dict_file)) | 113 | + if user_state_dict_file.exists(): |
114 | + state_dict = torch.load(str(user_state_dict_file)) | ||
115 | + elif system_state_dict_file.exists(): | ||
116 | + state_dict = torch.load(str(system_state_dict_file)) | ||
117 | + else: | ||
118 | + raise FileNotFoundError("Weights file not found") | ||
112 | model = Unet3D() | 119 | model = Unet3D() |
113 | model.load_state_dict(state_dict["model_state_dict"]) | 120 | model.load_state_dict(state_dict["model_state_dict"]) |
114 | model.to(device) | 121 | model.to(device) |