Fix type of HookedTransformerConfig.device#1230
Fix type of HookedTransformerConfig.device#1230brendanlong wants to merge 1 commit intoTransformerLensOrg:mainfrom
Conversation
This is either a torch.device or a string like "cpu", but it was typed as just `Optional[str]`. This fixes it to be `Optional[Union[str, torch.device]]` and all of the downstream places that need to be updated. Found while working on TransformerLensOrg#1219
| def get_device() -> torch.device: | ||
| if torch.cuda.is_available(): | ||
| return torch.device("cuda") | ||
| if torch.backends.mps.is_available() and torch.backends.mps.is_built(): | ||
| major_version = int(torch.__version__.split(".")[0]) | ||
| if major_version >= 2: | ||
| # Auto-select MPS if PyTorch is at or above the known-safe version | ||
| if ( | ||
| _MPS_MIN_SAFE_TORCH_VERSION is not None | ||
| and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION | ||
| ): | ||
| return torch.device("mps") | ||
| if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1": | ||
| return torch.device("mps") | ||
| logging.info( | ||
| "MPS device available but not auto-selected due to known correctness issues " | ||
| "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: " | ||
| "https://github.com/TransformerLensOrg/TransformerLens/issues/1178", | ||
| torch.__version__, | ||
| ) | ||
|
|
||
| return torch.device("cpu") |
There was a problem hiding this comment.
I'm actually unsure if we should change the type, or if we should change this function to just return the strings "cuda" or "cpu".
There was a problem hiding this comment.
I am of the same mind, I lean towards updating the function to return strings rather than the torch.device object.
After looking through the code this morning it appears a majority of the code base is correctly expecting strings, but there are a handful of locations that rely on the torch.device implementation. I need to determine if this is a bug or an intentional choice by past maintainers that went undocumented. I'll let you know when I have a plan for the direction.
There was a problem hiding this comment.
Ok, after some additional research and review of the code, we should correct this to always be string
Description
This is either a torch.device or a string like "cpu", but it was typed as just
Optional[str]. This fixes it to beOptional[Union[str, torch.device]]and all of the downstream places that need to be updated.Found while working on #1219.
Type of change
Please delete options that are not relevant.
Checklist: