Commit 91c68b26fbc93e4e77e1bf4df50b612fceb93c47

Authored by Thiago Franco de Moraes
1 parent 13163add
Exists in master

Try to open pytorch mri t1 brain weights in ai/brain_mri_t1/

Showing 1 changed file with 11 additions and 4 deletions   Show diff stats
invesalius/segmentation/brain/segment.py
... ... @@ -100,15 +100,22 @@ def brain_segment_torch(image, device_id, probability_array, comm_array):
100 100 import torch
101 101 from .model import Unet3D
102 102 device = torch.device(device_id)
103   - state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt")
104   - if not state_dict_file.exists():
  103 + folder = inv_paths.MODELS_DIR.joinpath("brain_mri_t1")
  104 + system_state_dict_file = folder.joinpath("brain_mri_t1.pt")
  105 + user_state_dict_file = inv_paths.USER_DL_WEIGHTS.joinpath("brain_mri_t1.pt")
  106 + if not system_state_dict_file.exists() and not user_state_dict_file.exists():
105 107 download_url_to_file(
106 108 "https://github.com/tfmoraes/deepbrain_torch/releases/download/v1.1.0/weights.pt",
107   - state_dict_file,
  109 + user_state_dict_file,
108 110 "194b0305947c9326eeee9da34ada728435a13c7b24015cbd95971097fc178f22",
109 111 download_callback(comm_array)
110 112 )
111   - state_dict = torch.load(str(state_dict_file))
  113 + if user_state_dict_file.exists():
  114 + state_dict = torch.load(str(user_state_dict_file))
  115 + elif system_state_dict_file.exists():
  116 + state_dict = torch.load(str(system_state_dict_file))
  117 + else:
  118 + raise FileNotFoundError("Weights file not found")
112 119 model = Unet3D()
113 120 model.load_state_dict(state_dict["model_state_dict"])
114 121 model.to(device)
... ...