diff --git a/.github/workflows/medcat-embedding-linker_ci.yml b/.github/workflows/medcat-embedding-linker_ci.yml index 2d23f3909..4b0e1d966 100644 --- a/.github/workflows/medcat-embedding-linker_ci.yml +++ b/.github/workflows/medcat-embedding-linker_ci.yml @@ -7,7 +7,7 @@ on: - 'medcat-embedding-linker/v*.*.*' pull_request: paths: - - 'medcat-embedding-linker/**' + - 'medcat-plugins/medcat-embedding-linker/**' - '.github/workflows/medcat-embedding-linker**' permissions: diff --git a/medcat-plugins/embedding-linker/pyproject.toml b/medcat-plugins/embedding-linker/pyproject.toml index be803b4a9..ea012489b 100644 --- a/medcat-plugins/embedding-linker/pyproject.toml +++ b/medcat-plugins/embedding-linker/pyproject.toml @@ -55,7 +55,7 @@ classifiers = [ # For an analysis of this field vs pip's requirements files see: # https://packaging.python.org/discussions/install-requires-vs-requirements/ dependencies = [ - "medcat[spacy]>=2.5", + "medcat[spacy]>=2.7", "transformers>=4.41.0,<5.0", # avoid major bump "torch>=2.4.0,<3.0", "tqdm", diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py index d5665a1f3..a855c0b5a 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py @@ -1,17 +1,42 @@ from typing import Optional, Any - from medcat.config import Linking class EmbeddingLinking(Linking): """The config exclusively used for the embedding linker""" + comp_name: str = "embedding_linker" """Changing compoenent name""" filter_before_disamb: bool = False + """Training on names or CUIs. If True all names of all CUIs will be used to train. + If false only CUIs preffered (or longest names will be used to train). 
Training on + names is more expensive computationally (and RAM/VRAM), but can lead to better + performance.""" + train_on_names: bool = True """Filtering CUIs before disambiguation""" - train: bool = False - """The embedding linker never needs to be trained in its - current implementation.""" + training_batch_size: int = 32 + """The size of the batch to be used for training.""" + embed_per_n_batches: int = 0 + """How many batches to train on before re-embedding the all names in the context + model. This is used to control how often the context model is updated during + training.""" + use_similarity_threshold: bool = True + """Do we have a similarity threshold we care about?""" + negative_sampling_k: int = 10 + """How many negative samples to generate for each positive sample during + training.""" + negative_sampling_candidate_pool_size: int = 4096 + """When generating negative samples, sample top_n candidates to consider when + sampling. Higher numbers will make training slower but can provide varied negative + samples.""" + negative_sampling_temperature: float = 0.1 + """Temperature to use when generating negative samples in training. Lower + temperatures will make the sampling more focused on the highest scoring candidates, + while higher temperatures will make it more random. Must be > 0.""" + use_mention_attention: bool = True + """Improves performance and fun to say. Mention attention can help the model focus + on the most relevant parts of the context when making linking decisions. Will only + pool on the tokens that contain the entity mention, with no context.""" long_similarity_threshold: float = 0.0 """Used in the inference step to choose the best CUI given the link candidates. Testing shows a threshold of 0.7 increases precision @@ -26,11 +51,16 @@ class EmbeddingLinking(Linking): embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2" """Name of the embedding model. 
It must be downloadable from huggingface linked from an appropriate file directory""" + use_projection_layer: bool = True + """Projection-layer default for trainable embedding linker.""" + top_n_layers_to_unfreeze: int = 0 + """LM unfreezing default for trainable embedding linker. + -1 unfreezes all LM layers, 0 freezes all LM layers, + n unfreezes the top n layers.""" max_token_length: int = 64 """Max number of tokens to be embedded from a name. If the max token length is changed then the linker will need to be created - with a new config. - """ + with a new config.""" embedding_batch_size: int = 4096 """How many pieces names can be embedded at once, useful when embedding name2info names, cui2info names""" @@ -44,5 +74,3 @@ class EmbeddingLinking(Linking): use_ner_link_candidates: bool = True """Link candidates are provided by some NER steps. This will flag if you want to trust them or not.""" - use_similarity_threshold: bool = True - """Do we have a similarity threshold we care about?""" diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py index 6bb1633c4..f41df7051 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py @@ -4,33 +4,43 @@ from medcat.components.types import AbstractEntityProvidingComponent from medcat.tokenizing.tokens import MutableEntity, MutableDocument from medcat.tokenizing.tokenizers import BaseTokenizer -from typing import Optional, Iterator, Set +from medcat_embedding_linker.transformer_context_model import ContextModel +from typing import Optional, Iterator, Set, Any from medcat.vocab import Vocab from medcat.utils.postprocessing import filter_linked_annotations -from tqdm import tqdm +from medcat_embedding_linker.config import EmbeddingLinking from collections import defaultdict +from torch 
import Tensor import logging -import math import numpy as np - -from medcat_embedding_linker.config import EmbeddingLinking - -from torch import Tensor -from transformers import AutoTokenizer, AutoModel -import torch.nn.functional as F import torch logger = logging.getLogger(__name__) class Linker(AbstractEntityProvidingComponent): - name = "embedding_linker" + comp_name = "embedding_linker" + _MODEL_FOLDER_NAME = "embedding_model" + _STATE_FILE_NAME = "state.json" - def __init__(self, cdb: CDB, config: Config) -> None: + # default model kwargs for embedding linkers that do not require training + DEFAULT_MODEL_INIT_KWARGS = { + "use_projection_layer": False, + "top_n_layers_to_unfreeze": 0, + } + + def __init__( + self, + cdb: CDB, + config: Config, + model_init_kwargs: Optional[dict[str, Any]] = None, + ) -> None: """Initializes the embedding linker with a CDB and configuration. Args: cdb (CDB): The concept database to use. config (Config): The base config. + model_init_kwargs (Optional[dict[str, Any]]): Explicit kwargs that + override linker defaults. 
""" super().__init__() self.cdb = cdb @@ -38,23 +48,20 @@ def __init__(self, cdb: CDB, config: Config) -> None: if not isinstance(config.components.linking, EmbeddingLinking): raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking - self.max_length = self.cnf_l.max_token_length + self.max_length: Optional[int] = self.cnf_l.max_token_length self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self._name_keys = list(self.cdb.name2info) - self._cui_keys = list(self.cdb.cui2info) - - # these only need to be populated when called for embedding or inference - self._names_context_matrix = None - self._cui_context_matrix = None - - # used for filters and name embedding, and if the name contains a valid cui - # see: _set_filters - self._last_include_set: Optional[Set[str]] = None - self._last_exclude_set: Optional[Set[str]] = None - self._allowed_mask = None - self._name_has_allowed_cui = None + resolved_model_init_kwargs: dict[str, Any] = dict( + self.DEFAULT_MODEL_INIT_KWARGS + ) + resolved_model_init_kwargs.update(model_init_kwargs or {}) + self.context_model = ContextModel( + cdb=self.cdb, + linking_config=self.cnf_l, + separator=self.config.general.separator, + model_init_kwargs=resolved_model_init_kwargs, + ) # checking for config settings that aren't used in this linker if self.cnf_l.prefer_frequent_concepts: logger.warning( @@ -69,6 +76,26 @@ def __init__(self, cdb: CDB, config: Config) -> None: "in the embedding linker. It is currently set to " f"{self.cnf_l.prefer_primary_name}." ) + self.refresh_structure() + + def refresh_structure(self) -> None: + """Call this method after making changes to the CDB to update internal + structures. Called upon initialization, and can be called manually after + CDB modifications. 
This is usually required when training on data that + might have new cuis or names.""" + self._name_keys = list(self.cdb.name2info) + self._cui_keys = list(self.cdb.cui2info) + + # Clear context matrices to force re-embedding with new CDB structure + self._names_context_matrix = None + self._cui_context_matrix = None + + # used for filters and name embedding, and if the name contains a valid cui + # see: _set_filters + self._last_include_set: Optional[Set[str]] = None + self._last_exclude_set: Optional[Set[str]] = None + self._allowed_mask = None + self._name_has_allowed_cui = None self._cui_to_idx = {cui: idx for idx, cui in enumerate(self._cui_keys)} self._name_to_idx = {name: idx for idx, name in enumerate(self._name_keys)} @@ -82,107 +109,27 @@ def __init__(self, cdb: CDB, config: Config) -> None: ] self._initialize_filter_structures() - def create_embeddings(self, - embedding_model_name: Optional[str] = None, - max_length: Optional[int] = None, - ): - """Create embeddings for all names and cuis longest names in the CDB - using the chosen embedding model.""" + def create_embeddings( + self, + embedding_model_name: Optional[str] = None, + max_length: Optional[int] = None, + ) -> None: + """Create both CUI and name embeddings in CDB.""" if embedding_model_name is None: - embedding_model_name = self.cnf_l.embedding_model_name # fallback + embedding_model_name = self.cnf_l.embedding_model_name if max_length is not None and max_length != self.max_length: logger.info( - "Updating max_length from %s to %s", self.max_length, max_length + "Updating max_length from %s to %s", self.max_length, max_length ) self.max_length = max_length self.cnf_l.max_token_length = max_length - if ( - embedding_model_name == self.cnf_l.embedding_model_name - and "cui_embeddings" in self.cdb.addl_info - ): - logger.warning("Using the same model for embedding names.") - else: - self.cnf_l.embedding_model_name = embedding_model_name - self._load_transformers(embedding_model_name) - 
self._embed_cui_names(embedding_model_name) - self._embed_names(embedding_model_name) - def _embed_cui_names( - self, - embedding_model_name: str, - ) -> None: - """Obtain embeddings for all cuis longest names in the CDB using the specified - embedding model and store them in the name2info.context_vectors - Args: - embedding_model_name (str): The name of the embedding model to use. - batch_size (int): The size of the batches to use when embedding names. - Default 4096 - """ - if ( - embedding_model_name == self.cnf_l.embedding_model_name - and "cui_embeddings" in self.cdb.addl_info - and "name_embeddings" in self.cdb.addl_info - ): - logger.warning("Using the same model for embedding.") - else: - self.cnf_l.embedding_model_name = embedding_model_name + self.context_model.embed_cuis(embedding_model_name) + self.context_model.embed_names(embedding_model_name) - # Use the longest name - cui_names = [ - max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys - ] - # embed each name in batches. 
Because there can be 3+ million names - total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size) - all_embeddings = [] - for names in tqdm( - self._batch_data(cui_names, self.cnf_l.embedding_batch_size), - total=total_batches, - desc="Embedding cuis' preferred names", - ): - with torch.no_grad(): - # removing ~ from names, as it is used to indicate a space in the CDB - names_to_embed = [ - name.replace(self.config.general.separator, " ") for name in names - ] - embeddings = self._embed(names_to_embed, self.device) - all_embeddings.append(embeddings.cpu()) - # cat all batches into one tensor - all_embeddings_matrix = torch.cat(all_embeddings, dim=0) - self.cdb.addl_info["cui_embeddings"] = all_embeddings_matrix - logger.debug("Embedding cui names done, total: %d", len(names)) - - def _embed_names(self, embedding_model_name: str) -> None: - """Obtain embeddings for all names in the CDB using the specified - embedding model and store them in the name2info.context_vectors - Args: - embedding_model_name (str): The name of the embedding model to use. - batch_size (int): The size of the batches to use when embedding names - Default 4096 - """ - if embedding_model_name == self.cnf_l.embedding_model_name: - logger.debug("Using the same model for embedding names.") - else: - self.cnf_l.embedding_model_name = embedding_model_name - names = self._name_keys - # embed each name in batches. 
Because there can be 3+ million names - total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size) - all_embeddings = [] - for names in tqdm( - self._batch_data(names, self.cnf_l.embedding_batch_size), - total=total_batches, - desc="Embedding names", - ): - with torch.no_grad(): - # removing ~ from names, as it is used to indicate a space in the CDB - names_to_embed = [ - name.replace(self.config.general.separator, " ") for name in names - ] - embeddings = self._embed(names_to_embed, self.device) - all_embeddings.append(embeddings.cpu()) - all_embeddings_matrix = torch.cat(all_embeddings, dim=0) - self.cdb.addl_info["name_embeddings"] = all_embeddings_matrix - logger.debug("Embedding names done, total: %d", len(names)) + self._names_context_matrix = None + self._cui_context_matrix = None def get_type(self) -> CoreComponentType: return CoreComponentType.linking @@ -191,51 +138,9 @@ def _batch_data(self, data, batch_size=512) -> Iterator[list]: for i in range(0, len(data), batch_size): yield data[i : i + batch_size] - def _load_transformers(self, embedding_model_name: str) -> None: - """Load the transformers model and tokenizer. - No need to load a transformer model until it's required. - Args: - embedding_model_name (str): The name of the embedding model to load. 
- Default is "sentence-transformers/all-MiniLM-L6-v2" - """ - if ( - not hasattr(self, "model") - or not hasattr(self, "tokenizer") - or embedding_model_name != self.cnf_l.embedding_model_name - ): - self.cnf_l.embedding_model_name = embedding_model_name - self.tokenizer = AutoTokenizer.from_pretrained(embedding_model_name) - self.model = AutoModel.from_pretrained(embedding_model_name) - self.model.eval() - gpu_device = self.cnf_l.gpu_device - self.device = torch.device( - gpu_device or ("cuda" if torch.cuda.is_available() else "cpu") - ) - self.model.to(self.device) - logger.debug( - f"""Loaded embedding model: {embedding_model_name} - on device: {self.device}""" - ) - - def _embed(self, to_embed: list[str], device) -> Tensor: - """Embeds a list of strings""" - batch_dict = self.tokenizer( - to_embed, - max_length=self.max_length, - padding=True, - truncation=True, - return_tensors="pt", - ).to(device) - outputs = self.model(**batch_dict) - outputs = self._last_token_pool( - outputs.last_hidden_state, batch_dict["attention_mask"] - ) - outputs = F.normalize(outputs, p=2, dim=1) - return outputs.half() - def _get_context( self, entity: MutableEntity, doc: MutableDocument, size: int - ) -> str: + ) -> tuple[str, tuple[int, int]]: """Get context tokens for an entity Args: @@ -244,25 +149,40 @@ def _get_context( size (int): The size of the entity. Returns: - tuple[list[BaseToken], list[BaseToken], list[BaseToken]]: - The tokens on the left, centre, and right. + tuple[str, tuple[int, int]]: + The context text and the span of the entity within that text. 
""" - start_ind = entity.base.start_index - end_ind = entity.base.end_index + # Token indices of the entity + start_token_idx = entity.base.start_index + end_token_idx = entity.base.end_index + + # Define token window + left_token_idx = max(0, start_token_idx - size) + right_token_idx = min(len(doc) - 1, end_token_idx + size) + + # Convert tokens → character offsets + left_most_token = doc[left_token_idx] + right_most_token = doc[right_token_idx] + + # For mention masking + snippet_start_char = left_most_token.base.char_index + snippet_end_char = right_most_token.base.char_index + len( + right_most_token.base.text + ) - left_most_token = doc[max(0, start_ind - size)] - left_index = left_most_token.base.char_index + # Slice raw document text + snippet = doc.base.text[snippet_start_char:snippet_end_char] - right_most_token = doc[min(len(doc) - 1, end_ind + size)] - right_index = right_most_token.base.char_index + len(right_most_token.base.text) + # Compute entity span relative to snippet + mention_start = entity.base.start_char_index - snippet_start_char + mention_end = entity.base.end_char_index - snippet_start_char - snippet = doc.base.text[left_index:right_index] - return snippet + return snippet, (mention_start, mention_end) def _get_context_vectors( self, doc: MutableDocument, entities: list[MutableEntity], size: int ) -> Tensor: - """Get context vectors for all detected concepts based on their + """Get context vectors for all detected concepts based on their surrounding text. 
Args: @@ -272,40 +192,41 @@ def _get_context_vectors( tuple[list[BaseToken], list[BaseToken], list[BaseToken]]: The tokens on the left, centre, and right.""" texts = [] + mention_spans = [] for entity in entities: - text = self._get_context(entity, doc, size) + text, span = self._get_context(entity, doc, size) texts.append(text) - return self._embed(texts, self.device) + mention_spans.append(span) + return self.context_model.embed(texts, mention_spans, self.device) def _initialize_filter_structures(self) -> None: """Call once during initialization to create efficient lookup structures.""" # Build an inverted index: cui_idx -> list of name indices that contain it # This is the KEY optimization - we flip the lookup direction - if not hasattr(self, '_cui_idx_to_name_idxs'): - cui2name_indices: defaultdict[ - int, list[int]] = defaultdict(list) - - for name_idx, cui_idxs in enumerate(self._name_to_cui_idxs): - for cui_idx in cui_idxs: - cui2name_indices[cui_idx].append(name_idx) - - # Convert lists to numpy arrays for faster indexing - self._cui_idx_to_name_idxs = { - cui_idx: np.array(name_idxs, dtype=np.int32) - for cui_idx, name_idxs in cui2name_indices.items() - } - - # Cache _has_cuis_all - if not hasattr(self, '_has_cuis_all_cached'): - self._has_cuis_all_cached = torch.tensor( - [bool(self.cdb.name2info[name]["per_cui_status"]) - for name in self._name_keys], - device=self.device, - dtype=torch.bool, - ) + cui2name_indices: defaultdict[int, list[int]] = defaultdict(list) - def _get_include_filters_1cui( - self, cui: str, n: int) -> torch.Tensor: + for name_idx, cui_idxs in enumerate(self._name_to_cui_idxs): + for cui_idx in cui_idxs: + cui2name_indices[cui_idx].append(name_idx) + + # Convert lists to numpy arrays for faster indexing + self._cui_idx_to_name_idxs = { + cui_idx: np.array(name_idxs, dtype=np.int32) + for cui_idx, name_idxs in cui2name_indices.items() + } + + # This used to be checked to be cached. + # But whenever it is called it is needed. 
+ self._has_cuis_all_cached = torch.tensor( + [ + bool(self.cdb.name2info[name]["per_cui_status"]) + for name in self._name_keys + ], + device=self.device, + dtype=torch.bool, + ) + + def _get_include_filters_1cui(self, cui: str, n: int) -> torch.Tensor: """Optimized single CUI include filter using inverted index.""" if cui not in self._cui_to_idx: return torch.zeros(n, dtype=torch.bool, device=self.device) @@ -324,11 +245,11 @@ def _get_include_filters_1cui( return torch.zeros(n, dtype=torch.bool, device=self.device) def _get_include_filters_multi_cui( - self, include_set: Set[str], n: int) -> torch.Tensor: + self, include_set: Set[str], n: int + ) -> torch.Tensor: """Optimized multi-CUI include filter using inverted index.""" include_cui_idxs = [ - self._cui_to_idx[cui] for cui in include_set - if cui in self._cui_to_idx + self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx ] if not include_cui_idxs: @@ -338,33 +259,30 @@ def _get_include_filters_multi_cui( all_name_indices_list: list[np.ndarray] = [] for cui_idx in include_cui_idxs: if cui_idx in self._cui_idx_to_name_idxs: - all_name_indices_list.append( - self._cui_idx_to_name_idxs[cui_idx]) + all_name_indices_list.append(self._cui_idx_to_name_idxs[cui_idx]) if not all_name_indices_list: return torch.zeros(n, dtype=torch.bool, device=self.device) # Concatenate and get unique indices - all_name_indices = np.unique( - np.concatenate(all_name_indices_list)) + all_name_indices = np.unique(np.concatenate(all_name_indices_list)) # Create mask allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device) allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = True return allowed_mask - def _get_include_filters( - self, include_set: Set[str], n: int) -> torch.Tensor: + def _get_include_filters(self, include_set: Set[str], n: int) -> torch.Tensor: """Route to appropriate include filter method.""" if len(include_set) == 1: cui = next(iter(include_set)) return 
self._get_include_filters_1cui(cui, n) else: - return self._get_include_filters_multi_cui( - include_set, n) + return self._get_include_filters_multi_cui(include_set, n) def _get_exclude_filters_1cui( - self, allowed_mask: torch.Tensor, cui: str) -> torch.Tensor: + self, allowed_mask: torch.Tensor, cui: str + ) -> torch.Tensor: """Optimized single CUI exclude filter using inverted index.""" if cui not in self._cui_to_idx: return allowed_mask @@ -374,18 +292,18 @@ def _get_exclude_filters_1cui( if cui_idx in self._cui_idx_to_name_idxs: name_indices = self._cui_idx_to_name_idxs[cui_idx] # Set specific indices to False - allowed_mask[ - torch.from_numpy(name_indices).to(self.device)] = False + allowed_mask[torch.from_numpy(name_indices).to(self.device)] = False return allowed_mask def _get_exclude_filters_multi_cui( - self, allowed_mask: torch.Tensor, exclude_set: Set[str], - ) -> torch.Tensor: + self, + allowed_mask: torch.Tensor, + exclude_set: Set[str], + ) -> torch.Tensor: """Optimized multi-CUI exclude filter using inverted index.""" exclude_cui_idxs = [ - self._cui_to_idx[cui] for cui in exclude_set - if cui in self._cui_to_idx + self._cui_to_idx[cui] for cui in exclude_set if cui in self._cui_to_idx ] if not exclude_cui_idxs: @@ -403,8 +321,7 @@ def _get_exclude_filters_multi_cui( return allowed_mask - def _get_exclude_filters( - self, exclude_set: Set[str], n: int) -> torch.Tensor: + def _get_exclude_filters(self, exclude_set: Set[str], n: int) -> torch.Tensor: """Route to appropriate exclude filter method.""" # Start with all allowed allowed_mask = torch.ones(n, dtype=torch.bool, device=self.device) @@ -414,11 +331,9 @@ def _get_exclude_filters( if len(exclude_set) == 1: cui = next(iter(exclude_set)) - return self._get_exclude_filters_1cui( - allowed_mask, cui) + return self._get_exclude_filters_1cui(allowed_mask, cui) else: - return self._get_exclude_filters_multi_cui( - allowed_mask, exclude_set) + return self._get_exclude_filters_multi_cui(allowed_mask, 
exclude_set) def _set_filters(self) -> None: include_set = self.cnf_l.filters.cuis @@ -436,11 +351,9 @@ def _set_filters(self) -> None: n = len(self._name_keys) if include_set: - allowed_mask = self._get_include_filters( - include_set, n) + allowed_mask = self._get_include_filters(include_set, n) else: - allowed_mask = self._get_exclude_filters( - exclude_set, n) + allowed_mask = self._get_exclude_filters(exclude_set, n) self._valid_names = self._has_cuis_all_cached & allowed_mask self._last_include_set = set(include_set) if include_set is not None else None @@ -459,7 +372,9 @@ def _disambiguate_by_cui( tuple[str, float]: The CUI and its similarity """ - cui_idxs = [self._cui_to_idx[cui] for cui in cui_candidates] + cui_idxs = [ + self._cui_to_idx[cui] for cui in cui_candidates if cui in self._cui_to_idx + ] candidate_scores = scores[cui_idxs] candidate_idx = int(torch.argmax(candidate_scores).item()) best_idx = cui_idxs[candidate_idx] @@ -499,16 +414,36 @@ def _inference( if len(link_candidates) == 1: best_idx = self._cui_to_idx[link_candidates[0]] predicted_cui = link_candidates[0] - similarity = names_scores[i, best_idx].item() + if best_idx < 0 or best_idx >= cui_scores.shape[1]: + logger.warning( + "Skipping entity '%s': single-candidate index %s is out of " + "bounds for cui_scores width %s.", + entity.detected_name, + best_idx, + cui_scores.shape[1], + ) + continue + similarity = cui_scores[i, best_idx].item() elif len(link_candidates) > 1: name_to_cuis = defaultdict(list) for cui in link_candidates: for name in self.cdb.cui2info[cui]["names"]: name_to_cuis[name].append(cui) - name_idxs = [self._name_to_idx[name] for name in name_to_cuis] + name_idxs = [ + self._name_to_idx[name] + for name in name_to_cuis + if name in self._name_to_idx + ] + if name_idxs == []: + logger.warning( + "No valid name indices for entity '%s' link candidates. 
" + "Likely stale linker structure after CDB mutation; call " + "refresh_structure() and recreate embeddings.", + entity.detected_name, + ) + continue indexed_scores = names_scores[i, name_idxs] - best_local_pos = int(torch.argmax(indexed_scores).item()) best_global_idx = name_idxs[best_local_pos] similarity = names_scores[i, best_global_idx].item() @@ -537,9 +472,7 @@ def _inference( predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) if not self.cnf_l.filters.check_filters(predicted_cui): continue - if self._check_similarity( - similarity - ): + if self._check_similarity(similarity): entity.cui = predicted_cui entity.context_similarity = similarity yield entity @@ -551,20 +484,6 @@ def _check_similarity(self, context_similarity: float) -> bool: else: return True - def _last_token_pool( - self, last_hidden_states: Tensor, attention_mask: Tensor - ) -> Tensor: - left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] - if left_padding: - return last_hidden_states[:, -1] - else: - sequence_lengths = attention_mask.sum(dim=1) - 1 - batch_size = last_hidden_states.shape[0] - return last_hidden_states[ - torch.arange(batch_size, device=last_hidden_states.device), - sequence_lengths, - ] - def _build_context_matrices(self) -> None: if "name_embeddings" in self.cdb.addl_info: self._names_context_matrix = ( @@ -613,7 +532,8 @@ def _generate_link_candidates( entity.link_candidates = list(cuis) - def _pre_inference(self, doc: MutableDocument + def _pre_inference( + self, doc: MutableDocument ) -> tuple[list[MutableEntity], list[MutableEntity]]: """Checking all entities for entites with only a single link candidate and to avoid full inference step. 
If we want to calculate similarities, or not use @@ -654,23 +574,14 @@ def _pre_inference(self, doc: MutableDocument to_infer.append(entity) return le, to_infer - def predict_entities(self, doc: MutableDocument, - ents: list[MutableEntity] | None = None - ) -> list[MutableEntity]: - if self.cdb.is_dirty: - logging.warning( - "CDB has been modified since last save/load. " - "This might significantly affect linking performance." - ) - logging.warning( - "If you have added new concepts or changes, " - "please re-embed the CDB names and cuis before linking." - ) - - self._load_transformers(self.cnf_l.embedding_model_name) - if self.cnf_l.train: + def predict_entities( + self, doc: MutableDocument, ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: + if self.cnf_l.train and self.comp_name == "embedding_linker": logger.warning( - "Attemping to train an embedding linker. This is not required." + "Attemping to train a static embedding linker. " + "This is not possible / required." + "Use the `trainable_embedding_linker` instead." 
) if self.cnf_l.filters.cuis and self.cnf_l.filters.cuis_exclude: logger.warning( diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/registration.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/registration.py index dc4fac3cb..c34377af2 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/registration.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/registration.py @@ -1,4 +1,3 @@ - import logging from medcat.components.types import CoreComponentType @@ -10,5 +9,14 @@ def do_registration(): lazy_register_core_component( - CoreComponentType.linking, "embedding_linker", - "medcat_embedding_linker.embedding_linker", "Linker.create_new_component") + CoreComponentType.linking, + "embedding_linker", + "medcat_embedding_linker.embedding_linker", + "Linker.create_new_component", + ) + lazy_register_core_component( + CoreComponentType.linking, + "trainable_embedding_linker", + "medcat_embedding_linker.trainable_embedding_linker", + "Linker.create_new_component", + ) diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py new file mode 100644 index 000000000..7efe4ead2 --- /dev/null +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py @@ -0,0 +1,401 @@ +from typing import Optional, Union +from medcat_embedding_linker.config import EmbeddingLinking +from torch import Tensor +from medcat.cdb import CDB +from medcat.config.config import Config, ComponentConfig +from medcat.components.linking.vector_context_model import PerDocumentTokenCache +from medcat.tokenizing.tokenizers import BaseTokenizer +from medcat.tokenizing.tokens import MutableDocument, MutableEntity +from medcat.vocab import Vocab +from medcat_embedding_linker.embedding_linker import Linker +from medcat.storage.serialisables import AbstractManualSerialisable 
+import logging +import torch +import os +import random + +logger = logging.getLogger(__name__) + + +class TrainableEmbeddingLinker(Linker, AbstractManualSerialisable): + """Trainable variant of the embedding linker. + This class inherits inference and embedding behavior from Linker and provides + method hooks for online/offline training. + """ + + comp_name = "trainable_embedding_linker" + _MODEL_FOLDER_NAME = "trainable_embedding_model" + _MODEL_STATE_FILE_NAME = "model_state.pt" + + def __init__(self, cdb: CDB, config: Config) -> None: + if not isinstance(config.components.linking, EmbeddingLinking): + raise TypeError("Linking config must be an EmbeddingLinking instance") + self.cnf_l: EmbeddingLinking = config.components.linking + # these by default are True, and 0 + # so a projection layer is used, but only the projection is trained + model_init_kwargs = { + "use_projection_layer": self.cnf_l.use_projection_layer, + "top_n_layers_to_unfreeze": self.cnf_l.top_n_layers_to_unfreeze, + } + super().__init__( + cdb, + config, + model_init_kwargs=model_init_kwargs, + ) + self.training_batch: list[tuple] = [] + self.number_of_batches = 0 + self.negative_sampling_candidate_pool_size = ( + self.cnf_l.negative_sampling_candidate_pool_size + ) + self.scaler = torch.amp.GradScaler() # for FP16 training stability + self.optimizer = torch.optim.AdamW( + self.context_model.model.parameters(), lr=1e-4, weight_decay=0.01 + ) + + def _generate_negative_samples( + self, + candidate_indices: Tensor, + candidate_scores: Tensor, + positive_target_idxs_per_row: list[list[int]], + ) -> Tensor: + """Sample negative target indices for each entity in a batch. + + Args: + candidate_indices (Tensor): Candidate target indices, shape + ``[batch, num_candidates]``. + candidate_scores (Tensor): Scores for candidate targets aligned with + ``candidate_indices``, shape ``[batch, num_candidates]``. 
+ positive_target_idxs_per_row (list[list[int]]): Per-row target indices that + must be excluded from negatives. + + Returns: + Tensor: Sampled negative name indices with shape ``[batch, k]`` (or + ``[k]`` for single-item input). + """ + k = self.cnf_l.negative_sampling_k + temperature = self.cnf_l.negative_sampling_temperature + + # Exclude positives from sampling by masking their scores. + positive_mask = torch.zeros_like(candidate_indices, dtype=torch.bool) + for row_idx, row_positive_idxs in enumerate(positive_target_idxs_per_row): + if not row_positive_idxs: + continue + row_positive_tensor = torch.tensor( + row_positive_idxs, + device=candidate_indices.device, + dtype=candidate_indices.dtype, + ) + positive_mask[row_idx] = torch.isin( + candidate_indices[row_idx], + row_positive_tensor, + ) + candidate_scores = candidate_scores.masked_fill(positive_mask, float("-inf")) + + probs = torch.softmax(candidate_scores / temperature, dim=1) + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + + max_samples = min(k, candidate_indices.size(1)) + valid_counts = (~positive_mask).sum(dim=1) + sample_count = min(max_samples, int(valid_counts.min().item())) + if sample_count <= 0: + return candidate_indices.new_empty((candidate_indices.size(0), 0)) + + sampled_positions = torch.multinomial( + probs, + num_samples=sample_count, + replacement=False, + ) + negative_indices = torch.gather( + candidate_indices, dim=1, index=sampled_positions + ) + return negative_indices + + def _build_batch_context_inputs( + self, batch: list[tuple] + ) -> tuple[list[str], list[tuple[int, int]]]: + """Convert a batch into model-ready text snippets and mention spans.""" + texts: list[str] = [] + mention_spans: list[tuple[int, int]] = [] + for doc, entity, *_ in batch: + snippet, mention_span = self._get_context( + entity, + doc, + self.cnf_l.context_window_size, + ) + texts.append(snippet) + mention_spans.append(mention_span) + return texts, mention_spans + + def 
_train_on_batch_targets( + self, + target_matrix: Tensor, + positive_target_idxs: list[int], + all_positive_target_idxs_per_row: list[list[int]], + ) -> None: + """Shared contrastive training path for both name and CUI targets.""" + if self.training_batch == []: + return + + texts, mention_spans = self._build_batch_context_inputs(self.training_batch) + + self.optimizer.zero_grad() + with torch.amp.autocast( + device_type=str(self.device) + ): # controls FP16 usage for better stability + # Forward pass to get context vectors for each entity in the batch. + self.context_model.model.train() + + context_vectors = self.context_model.embed( + texts, + mention_spans=mention_spans, + ) # [batch, dim] + + # Target embeddings are fixed; no gradient flows through them. + target_matrix = target_matrix.detach() # [num_targets, dim] + + # Negative sampling does not need gradients. + with torch.no_grad(): + target_scores = context_vectors.detach() @ target_matrix.T + candidate_pool_size = min( + target_scores.size(1), self.negative_sampling_candidate_pool_size + ) + + candidate_scores, candidate_indices = torch.topk( + target_scores, + k=candidate_pool_size, + dim=1, + largest=True, + sorted=False, + ) + negative_indices = self._generate_negative_samples( + candidate_indices, + candidate_scores, + all_positive_target_idxs_per_row, + ) + + pos_idx_tensor = torch.tensor(positive_target_idxs, device=self.device) + positive_embeds = target_matrix[pos_idx_tensor] + negative_embeds = target_matrix[negative_indices] + + positive_scores = (context_vectors * positive_embeds).sum( + dim=1, keepdim=True + ) + negative_scores = torch.bmm( + negative_embeds, context_vectors.unsqueeze(-1) + ).squeeze(-1) + + logits = torch.cat([positive_scores, negative_scores], dim=1) + # The target is always the first position (the positive sample). + # So these are idx's of 0 in the logits tensor, not target indices. 
+ targets = torch.zeros( + len(self.training_batch), dtype=torch.long, device=self.device + ) + + loss = torch.nn.functional.cross_entropy(logits, targets) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + self.context_model.model.eval() + logger.debug("Training batch loss: %.4f", loss.item()) + + def _train_on_batch_cuis(self) -> None: + """Train on a batch of CUI-based tuples. + + Runs a contrastive forward pass through the context encoder, computes + cross-entropy loss over one positive and k negative CUI embeddings, and + performs a single optimizer step. + """ + if self.training_batch == []: + return + + positive_cui_idxs = [sample[2] for sample in self.training_batch] + if len(self.training_batch[0]) >= 4: + all_positive_cui_idxs_per_row = [ + sample[3] for sample in self.training_batch + ] + else: + all_positive_cui_idxs_per_row = [[pos_idx] for pos_idx in positive_cui_idxs] + + self._train_on_batch_targets( + self.cui_context_matrix, + positive_cui_idxs, + all_positive_cui_idxs_per_row, + ) + + def _train_on_batch_names(self) -> None: + """Train on a batch of + (doc, entity, positive_name_idx, all_positive_name_idxs) tuples.""" + if self.training_batch == []: + return + + positive_name_idxs = [sample[2] for sample in self.training_batch] + all_positive_name_idxs_per_row = [sample[3] for sample in self.training_batch] + + self._train_on_batch_targets( + self.names_context_matrix, + positive_name_idxs, + all_positive_name_idxs_per_row, + ) + + def _train_on_batch(self) -> None: + """Train on the current batch, dispatching to names or CUI mode. + + This should also be called manually at the end of training to flush any + remaining samples that didn't fill a batch. 
+ + Args: + training_batch (list[tuple]): + Name mode: (doc, entity, positive_name_idx, all_positive_name_idxs) + CUI mode: (doc, entity, positive_cui_idx) + """ + if self.training_batch == []: + return + + tuple_lengths = {len(sample) for sample in self.training_batch} + if len(tuple_lengths) != 1: + raise ValueError( + "Mixed training batch formats detected. " + "Expected uniform tuples for names (len=4) or CUIs (len=3)." + ) + + sample_len = tuple_lengths.pop() + if sample_len == 4: + # A len=4 tuple is interpreted as name mode by default. + self._train_on_batch_names() + return + if sample_len == 3: + self._train_on_batch_cuis() + return + + raise ValueError( + f"Unsupported training batch tuple size: {sample_len}. " + "Expected len=3 for CUIs or len=4 for names." + ) + + def train( + self, + cui: str, + entity: MutableEntity, + doc: MutableDocument, + negative: bool = False, + names: Union[list[str], dict] = [], + per_doc_valid_token_cache: Optional[PerDocumentTokenCache] = None, + ) -> None: + """Train the linker. + + This simply trains the context model. + + This will collect samples to train in batches. Once a batch is ready, the + forward pass will be done and gradients will be collected. + + Args: + cui (str): The ground truth label for the entity. + entity (BaseEntity): The entity we're at. + doc (BaseDocument): The document within which we're working. + negative (bool): To be ignored here. + names (list[str]/dict): + Unused within the embedding linker, but required for the interface. + Used to provide the names of the concept for which we're training. + per_doc_valid_token_cache (PerDocumentTokenCache): + Unused within the embedding linker, but required for the interface. + """ + if negative: + logger.warning( + "Negative samples are not currently used in training the " \ + "embedding linker. Skipping." 
+ ) + return + if self.cnf_l.train_on_names: + # Name mode: sample one positive name and keep all positive aliases + # for this CUI so aliases can be excluded from negatives. + positive_samples = self.cdb.cui2info[cui]["names"] + all_positive_name_idxs: list[int] = [] + for pos_sample in positive_samples: + pos_idx = self._name_to_idx.get(pos_sample) + if pos_idx is not None: + all_positive_name_idxs.append(pos_idx) + if not all_positive_name_idxs: + return + pos_idx = random.choice(all_positive_name_idxs) + self.training_batch.append((doc, entity, pos_idx, all_positive_name_idxs)) + else: + # CUI mode: one positive CUI index per row. + positive_cui_idx = self._cui_to_idx.get(cui) + if positive_cui_idx is None: + return + self.training_batch.append((doc, entity, positive_cui_idx)) + if ( + len(self.training_batch) >= self.cnf_l.training_batch_size + or entity is doc.ner_ents[-1] + ): + logger.debug( + "End of document reached; training on final batch of size %s", + len(self.training_batch), + ) + self._train_on_batch() + self.training_batch = [] + self.number_of_batches += 1 + # If you've got as many batches as you want before re-embedding, + # then do it and reset the counter. + if ( + self.cnf_l.embed_per_n_batches > 0 + and self.number_of_batches > self.cnf_l.embed_per_n_batches + ): + logger.debug( + "Re-embedding names and CUIs after training on %s batches.", + self.number_of_batches, + ) + self.refresh_structure() + # Always refresh both embeddings to keep CDB and embeddings in sync. + # Inference always uses both names_context_matrix and cui_context_matrix. 
+ # And inference is called during the cat.trainer.train() loop + self.context_model.embed_names() + self.context_model.embed_cuis() + self._names_context_matrix = None + self._cui_context_matrix = None + self.number_of_batches = 0 + + @classmethod + def create_new_component( + cls, + cnf: ComponentConfig, + tokenizer: BaseTokenizer, + cdb: CDB, + vocab: Vocab, + model_load_path: Optional[str], + ) -> "TrainableEmbeddingLinker": + return cls(cdb, cdb.config) + + def serialise_to(self, folder_path: str) -> None: + # Ensure final partial batch is not dropped before saving model state. + logger.info("Flushing final training batch before saving model.") + logger.info("This is grandfathered in from trainer.py restraints.") + + os.makedirs(folder_path, exist_ok=True) + model_folder = os.path.join(folder_path, self._MODEL_FOLDER_NAME) + os.makedirs(model_folder, exist_ok=True) + + torch.save( + self.context_model.model.state_dict(), + os.path.join(model_folder, self._MODEL_STATE_FILE_NAME), + ) + + @classmethod + def deserialise_from( + cls, folder_path: str, **init_kwargs + ) -> "TrainableEmbeddingLinker": + cdb = init_kwargs["cdb"] + linker = cls(cdb, cdb.config) + + model_state_path = os.path.join( + folder_path, cls._MODEL_FOLDER_NAME, cls._MODEL_STATE_FILE_NAME + ) + if os.path.exists(model_state_path): + state_dict = torch.load(model_state_path, map_location=linker.device) + linker.context_model.model.load_state_dict(state_dict) + + linker._names_context_matrix = None + linker._cui_context_matrix = None + return linker diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py new file mode 100644 index 000000000..bd058b802 --- /dev/null +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py @@ -0,0 +1,376 @@ +from pathlib import Path +from typing import Any, Iterator, Optional, Union +from 
medcat.storage.serialisables import AbstractSerialisable +from torch import Tensor, nn +from transformers import AutoModel, AutoTokenizer +from tqdm import tqdm +import json +import logging +import math +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +class ModelForEmbeddingLinking(nn.Module): + """Wrapper around a Hugging Face transformer for embedding-based linking. + + The model applies mean pooling over token embeddings, optionally projects the + pooled vector, and L2 normalizes the final embedding. + """ + + def __init__( + self, + embedding_model_name: str, + use_projection_layer: bool = False, + top_n_layers_to_unfreeze: int = -1, + device: Optional[Union[str, torch.device]] = None, + ) -> None: + super().__init__() + self.language_model = AutoModel.from_pretrained(embedding_model_name) + self.base_model_name = self.language_model.name_or_path + + self.use_projection_layer = use_projection_layer + self.top_n_layers_to_unfreeze = top_n_layers_to_unfreeze + + hidden_size = self.language_model.config.hidden_size + if self.use_projection_layer: + self.projection_layer = nn.Linear(hidden_size, hidden_size) + + self._freeze_all_parameters() + self.unfreeze_top_n_lm_layers(self.top_n_layers_to_unfreeze) + + target_device = self._resolve_device(device) + self.to(target_device) + + @staticmethod + def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device: + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @staticmethod + def masked_mean_pooling(token_embeddings: Tensor, mask: Tensor) -> Tensor: + mask = mask.unsqueeze(-1).float() + summed = torch.sum(token_embeddings * mask, dim=1) + counts = torch.clamp(mask.sum(dim=1), min=1e-9) + return summed / counts + + def forward(self, **inputs) -> Tensor: + # Don't pass the mention_mask to the 
language model if it does exist + mention_mask = inputs.pop("mention_mask", None) + model_output = self.language_model(**inputs) + + pooling_mask = ( + mention_mask if mention_mask is not None else inputs["attention_mask"] + ) + sentence_embeddings = self.masked_mean_pooling( + model_output.last_hidden_state, pooling_mask + ) + + if self.use_projection_layer: + sentence_embeddings = self.projection_layer(sentence_embeddings) + return F.normalize(sentence_embeddings, p=2, dim=1) + + def _freeze_all_parameters(self) -> None: + for param in self.language_model.parameters(): + param.requires_grad = False + + if self.use_projection_layer: + for param in self.projection_layer.parameters(): + param.requires_grad = True + + def unfreeze_top_n_lm_layers(self, n: int) -> None: + # train all LM layers - each layer requires more data + if n == -1: + for param in self.language_model.parameters(): + param.requires_grad = True + return + + # keep LM fully frozen - better with less data + if n == 0: + return + + # BERT-likes + if hasattr(self.language_model, "encoder") and hasattr( + self.language_model.encoder, "layer" + ): + layers = self.language_model.encoder.layer + # DistilBERT-likes + elif hasattr(self.language_model, "transformer") and hasattr( + self.language_model.transformer, "layer" + ): + layers = self.language_model.transformer.layer + else: + raise ValueError("Unsupported LM architecture for layer unfreezing.") + + total_layers = len(layers) + n = min(n, total_layers) + for layer in layers[-n:]: + for param in layer.parameters(): + param.requires_grad = True + + def save_pretrained(self, save_directory: Union[str, Path]) -> None: + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + torch.save(self.state_dict(), save_path / "pytorch_model.bin") + + config = { + "embedding_model_name": self.base_model_name, + "use_projection_layer": self.use_projection_layer, + "top_n_layers_to_unfreeze": self.top_n_layers_to_unfreeze, + } + with 
open(save_path / "config.json", "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + + @classmethod + def from_pretrained( + cls, + path_or_model_name: Union[str, Path], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> "ModelForEmbeddingLinking": + path = Path(path_or_model_name) + config_path = path / "config.json" + weights_path = path / "pytorch_model.bin" + target_device = cls._resolve_device(device) + + # Local saved wrapper model. + if config_path.exists() and weights_path.exists(): + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + + config.update(kwargs) + model = cls(**config) + state_dict = torch.load(weights_path, map_location="cpu") + model.load_state_dict(state_dict) + model.to(target_device) + return model + + # Hugging Face model id/path. + model = cls( + embedding_model_name=str(path_or_model_name), + device=target_device, + **kwargs, + ) + return model + + +class ContextModel(AbstractSerialisable): + """Encapsulates embedding model/tokenizer loading and CDB embedding creation.""" + + def __init__( + self, + cdb, + linking_config, + separator: str, + model_init_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + self.cdb = cdb + self.cnf_l = linking_config + self.separator = separator + self._model_init_kwargs = dict(model_init_kwargs or {}) + self.max_length = self.cnf_l.max_token_length + self.device = torch.device( + self.cnf_l.gpu_device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + self._refresh_cdb_keys() + self._cui_keys = list(self.cdb.cui2info) + self._loaded_model_source: Optional[str] = None + self._loaded_model_init_kwargs: Optional[dict[str, Any]] = None + self.load_transformers(self.cnf_l.embedding_model_name) + + def _batch_data(self, data, batch_size=512) -> Iterator[list]: + for i in range(0, len(data), batch_size): + yield data[i : i + batch_size] + + def _refresh_cdb_keys(self) -> None: + """Refresh key caches from current CDB state. 
+
+
+        This is required after initialisation and for training-time CDB
+        mutations where new names/CUIs may be introduced.
+        """
+        self._name_keys = list(self.cdb.name2info)
+        self._cui_keys = list(self.cdb.cui2info)
+
+    @staticmethod
+    def _resolve_model_source(path_or_model_name: Union[str, Path]) -> str:
+        """Return local absolute path if it exists, otherwise keep HF model id."""
+        candidate = Path(path_or_model_name).expanduser()
+        if candidate.exists():
+            return str(candidate.resolve())
+        return str(path_or_model_name)
+
+    def _get_model_init_kwargs(self) -> dict[str, Any]:
+        """Build kwargs passed to ModelForEmbeddingLinking.from_pretrained."""
+        return dict(self._model_init_kwargs)
+
+    def load_transformers(self, embedding_model_name: Union[str, Path]) -> None:
+        """Load tokenizer/model from local path or Hugging Face model id."""
+        model_source = self._resolve_model_source(embedding_model_name)
+        model_init_kwargs = self._get_model_init_kwargs()
+        if (
+            not hasattr(self, "model")
+            or not hasattr(self, "tokenizer")
+            or model_source != self._loaded_model_source
+            or model_init_kwargs != self._loaded_model_init_kwargs
+        ):
+            self.cnf_l.embedding_model_name = str(embedding_model_name)
+            self.tokenizer = AutoTokenizer.from_pretrained(model_source)
+            self.model = ModelForEmbeddingLinking.from_pretrained(
+                model_source, **model_init_kwargs
+            )
+            self.model.eval()
+            self.device = torch.device(
+                self.cnf_l.gpu_device
+                or ("cuda" if torch.cuda.is_available() else "cpu")
+            )
+            self.model.to(self.device)
+            self._loaded_model_source = model_source
+            self._loaded_model_init_kwargs = model_init_kwargs
+            logger.debug(
+                "Loaded embedding model: %s (resolved source: %s) with kwargs=%s " \
+                "on device: %s",
+                embedding_model_name,
+                model_source,
+                model_init_kwargs,
+                self.device,
+            )
+
+    def _build_mention_mask_from_char_spans(
+        self,
+        batch_dict: dict[str, Tensor],
+        mention_char_spans: list[tuple[int, int]],
+        device: torch.device,
+    ) -> Tensor:
+        """
+        Convert 
character-level mention spans into a token-level mask. + + Args: + batch_dict: tokenizer output with 'offset_mapping' + mention_char_spans: list of (start_char, end_char) per example + device: torch device + + Returns: + mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens, + 0.0 otherwise + """ + offset_mapping = batch_dict["offset_mapping"] # [B, max_token_length, 2] + batch_size, seq_len, _ = offset_mapping.shape + mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device) + + for i, (mention_start, mention_end) in enumerate(mention_char_spans): + # For each token in the sequence + for j in range(seq_len): + token_start, token_end = offset_mapping[i, j].tolist() + # Skip padding tokens + if token_end == 0 and token_start == 0: + continue + # Check if token overlaps mention span + if token_end > mention_start and token_start < mention_end: + mask[i, j] = 1.0 + + return mask + + def embed( + self, + to_embed: list[str], + mention_spans: Optional[list[tuple[int, int]]] = None, + device: Optional[torch.device] = None, + ) -> Tensor: + """Embed a list of input strings.""" + target_device = device or self.device + # we don't need offset mapping when just embedding potential labels + need_offsets_mapping = ( + self.cnf_l.use_mention_attention and mention_spans is not None + ) + + batch_dict = self.tokenizer( + to_embed, + max_length=self.max_length, + padding=True, + truncation=True, + return_tensors="pt", + return_offsets_mapping=need_offsets_mapping, + ).to(target_device) + + mention_mask = None + if mention_spans is not None: + mention_mask = self._build_mention_mask_from_char_spans( + batch_dict, + mention_spans, + target_device, + ) + batch_dict["mention_mask"] = mention_mask + + # Keep tokenizer-only metadata out of model forward kwargs. 
+ batch_dict.pop("offset_mapping", None) + + outputs = self.model(**batch_dict) + return outputs.half() + + def embed_cuis( + self, embedding_model_name: Optional[Union[str, Path]] = None + ) -> None: + """Create embeddings for each CUI's longest name and store in CDB. + + If ``embedding_model_name`` is provided, switch/load that model first. + Otherwise, reuse the currently loaded model (training-friendly default). + """ + target_model = embedding_model_name or self.cnf_l.embedding_model_name + self._refresh_cdb_keys() # ensure _cui_keys is up to date before embedding + self.load_transformers(target_model) + + cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys] + total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size) + all_embeddings = [] + for names in tqdm( + self._batch_data(cui_names, self.cnf_l.embedding_batch_size), + total=total_batches, + desc="Embedding cuis' preferred names", + ): + with torch.no_grad(): + names_to_embed = [name.replace(self.separator, " ") for name in names] + embeddings = self.embed(names_to_embed, device=self.device) + all_embeddings.append(embeddings.cpu()) + + all_embeddings_matrix = torch.cat(all_embeddings, dim=0) + self.cdb.addl_info["cui_embeddings"] = all_embeddings_matrix + logger.debug("Embedding cui names done, total: %d", len(cui_names)) + + def embed_names( + self, embedding_model_name: Optional[Union[str, Path]] = None + ) -> None: + """Create embeddings for all names and store in CDB. + + If ``embedding_model_name`` is provided, switch/load that model first. + Otherwise, reuse the currently loaded model (training-friendly default). 
+ """ + target_model = embedding_model_name or self.cnf_l.embedding_model_name + self._refresh_cdb_keys() # ensure _cui_keys is up to date before embedding + self.load_transformers(target_model) + + names = self._name_keys + total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size) + all_embeddings = [] + for batch_names in tqdm( + self._batch_data(names, self.cnf_l.embedding_batch_size), + total=total_batches, + desc="Embedding names", + ): + with torch.no_grad(): + names_to_embed = [ + name.replace(self.separator, " ") for name in batch_names + ] + embeddings = self.embed(names_to_embed, device=self.device) + all_embeddings.append(embeddings.cpu()) + + all_embeddings_matrix = torch.cat(all_embeddings, dim=0) + self.cdb.addl_info["name_embeddings"] = all_embeddings_matrix + logger.debug("Embedding names done, total: %d", len(names)) diff --git a/medcat-plugins/embedding-linker/tests/__init__.py b/medcat-plugins/embedding-linker/tests/__init__.py index 824839547..b40364e1c 100644 --- a/medcat-plugins/embedding-linker/tests/__init__.py +++ b/medcat-plugins/embedding-linker/tests/__init__.py @@ -12,13 +12,14 @@ # unpack model pack at start so we can access stuff like Vocab print("Unpacking included test model pack") -shutil.unpack_archive( - EXAMPLE_MODEL_PACK_ZIP, UNPACKED_EXAMPLE_MODEL_PACK_PATH) +shutil.unpack_archive(EXAMPLE_MODEL_PACK_ZIP, UNPACKED_EXAMPLE_MODEL_PACK_PATH) def _del_unpacked_model(): - print("Cleaning up! Removing unpacked exmaple model pack:", - UNPACKED_EXAMPLE_MODEL_PACK_PATH) + print( + "Cleaning up! 
Removing unpacked exmaple model pack:", + UNPACKED_EXAMPLE_MODEL_PACK_PATH, + ) shutil.rmtree(UNPACKED_EXAMPLE_MODEL_PACK_PATH) diff --git a/medcat-plugins/embedding-linker/tests/helper.py b/medcat-plugins/embedding-linker/tests/helper.py index 9513f3c43..4a70e259d 100644 --- a/medcat-plugins/embedding-linker/tests/helper.py +++ b/medcat-plugins/embedding-linker/tests/helper.py @@ -5,7 +5,6 @@ class FakeCDB: - def __init__(self, cnf: Config): self.config = cnf self.token_counts = {} @@ -26,7 +25,7 @@ class FTokenizer: class ComponentInitTests: expected_def_components = 1 - default = 'default' + default = "default" # these need to be specified when overriding comp_type: types.CoreComponentType default_cls: Type[types.BaseComponent] @@ -38,19 +37,22 @@ def setUpClass(cls): cls.fcdb = FakeCDB(cls.cnf) cls.fvocab = FVocab() cls.vtokenizer = FTokenizer() - cls.comp_cnf: ComponentConfig = getattr( - cls.cnf.components, cls.comp_type.name) + cls.comp_cnf: ComponentConfig = getattr(cls.cnf.components, cls.comp_type.name) if isinstance(cls.default_creator, Type): cls._def_creator_name_opts = (cls.default_creator.__name__,) else: # classmethod - cls._def_creator_name_opts = (".".join(( - # etiher class.method_name - cls.default_creator.__self__.__name__, - cls.default_creator.__name__)), + cls._def_creator_name_opts = ( + ".".join( + ( + # etiher class.method_name + cls.default_creator.__self__.__name__, + cls.default_creator.__name__, + ) + ), # or just method_name - cls.default_creator.__name__ - ) + cls.default_creator.__name__, + ) def test_has_default(self): avail_components = types.get_registered_components(self.comp_type) @@ -58,8 +60,9 @@ def test_has_default(self): name, cls_name = avail_components[0] # 1 name / cls name eq_name = [name == self.default for name, _ in avail_components] - eq_cls = [cls_name in self._def_creator_name_opts - for _, cls_name in avail_components] + eq_cls = [ + cls_name in self._def_creator_name_opts for _, cls_name in avail_components + 
] self.assertEqual(sum(eq_name), 1) # NOTE: for NER both the default as well as the Dict based NER # have the came class name, so may be more than 1 @@ -70,7 +73,12 @@ def test_has_default(self): def test_can_create_def_component(self): component = types.create_core_component( self.comp_type, - self.default, self.cnf, self.vtokenizer, self.fcdb, self.fvocab, None) - self.assertIsInstance(component, - runtime_checkable(types.BaseComponent)) + self.default, + self.cnf, + self.vtokenizer, + self.fcdb, + self.fvocab, + None, + ) + self.assertIsInstance(component, runtime_checkable(types.BaseComponent)) self.assertIsInstance(component, self.default_cls) diff --git a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py index e7923426f..06dedc031 100644 --- a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py +++ b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py @@ -1,4 +1,5 @@ from medcat_embedding_linker import embedding_linker +from medcat_embedding_linker import trainable_embedding_linker from medcat.components import types from medcat.config import Config from medcat.data.entities import Entity @@ -12,16 +13,20 @@ from . 
import UNPACKED_EXAMPLE_MODEL_PACK_PATH + class FakeDocument: linked_ents = [] ner_ents = [] + def __init__(self, text): self.text = text + class FakeTokenizer: def __call__(self, text: str) -> FakeDocument: return FakeDocument(text) + class FakeCDB: def __init__(self, config: Config): self.is_dirty = False @@ -37,7 +42,7 @@ def weighted_average_function(self, nr: int) -> float: class EmbeddingLinkerInitTests(ComponentInitTests, unittest.TestCase): expected_def_components = len(DEF_LINKING) comp_type = types.CoreComponentType.linking - default = 'embedding_linker' + default = "embedding_linker" default_cls = embedding_linker.Linker default_creator = embedding_linker.Linker.create_new_component module = embedding_linker @@ -57,6 +62,7 @@ def test_has_default(self): registered_names = [name for name, _ in avail_components] self.assertIn("embedding_linker", registered_names) + class NonTrainableEmbeddingLinkerTests(unittest.TestCase): cnf = Config() cnf.components.linking = embedding_linker.EmbeddingLinking() @@ -71,6 +77,18 @@ def test_linker_processes_document(self): self.linker(doc) +class TrainableEmbeddingLinkerTests(unittest.TestCase): + cnf = Config() + cnf.components.linking = embedding_linker.EmbeddingLinking() + cnf.components.linking.comp_name = ( + trainable_embedding_linker.TrainableEmbeddingLinker.name + ) + linker = trainable_embedding_linker.TrainableEmbeddingLinker(FakeCDB(cnf), cnf) + + def test_linker_is_trainable(self): + self.assertIsInstance(self.linker, TrainableComponent) + + class EmbeddingModelDisambiguationTests(unittest.TestCase): PLACEHOLDER = "{SOME_PLACEHOLDER}" TEXT = f"""The issue has a lot to do with the {PLACEHOLDER}""" @@ -81,7 +99,8 @@ def setUpClass(cls) -> None: cls.model.config.components.linking = embedding_linker.EmbeddingLinking() cls.model._recreate_pipe() linker: embedding_linker.Linker = cls.model.pipe.get_component( - types.CoreComponentType.linking) + types.CoreComponentType.linking + ) linker.create_embeddings() 
cls.linker = linker @@ -89,14 +108,12 @@ def test_is_correct_linker(self): self.assertIsInstance(self.linker, embedding_linker.Linker) def assert_has_name(self, out_ents: dict[int, Entity], name: str): - self.assertTrue( - any(ent["source_value"] == name for ent in out_ents.values()) - ) + self.assertTrue(any(ent["source_value"] == name for ent in out_ents.values())) def test_does_disambiguation(self): used_names = 0 for name, info in self.model.cdb.name2info.items(): - if len(info['per_cui_status']) <= 1: + if len(info["per_cui_status"]) <= 1: continue used_names += 1 with self.subTest(name): @@ -104,4 +121,3 @@ def test_does_disambiguation(self): out_ents = self.model.get_entities(cur_text)["entities"] self.assert_has_name(out_ents, name) self.assertGreater(used_names, 0) -