feat(medcat-plugins): trainable linker for the embedding linker #392
adam-sutton-1992 wants to merge 7 commits into main
Conversation
This PR should make the workflow actually run upon changes.
mart-r
left a comment
I must say, some of the details here go over my head.
But I think overall it looks really good!
I did leave a few comments / questions. So feel free to address what's relevant.
medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py
```python
def serialise_to(self, folder_path: str) -> None:
    # Ensure final partial batch is not dropped before saving model state.
    logger.info("Flushing final training batch before saving model.")
    logger.info("This is grandfathered in from trainer.py restraints.")
```
I see this is the main workhorse in terms of working around the restrictions before #374.
Perhaps we can add a check for the medcat version and, if it's greater than or equal to 2.7, avoid the train here and the log message?
Plus, we'd need to implement the change in the train method as well, i.e. if it's 2.7 or later, just train on the batch upon the last entity in the document.
But then again, doing it after every document might be problematic if the documents are short (i.e. not many entities) and/or the last batch in them has very few entities. So I'm open to leaving it as is if you prefer.
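To make the version-gate suggestion concrete, here's a minimal sketch; `needs_explicit_batch_flush` is a hypothetical helper name, and in practice the version string would come from something like `importlib.metadata.version("medcat")`:

```python
def needs_explicit_batch_flush(medcat_version: str) -> bool:
    """Return True when the installed medcat predates the fixes from #374
    (i.e. is older than 2.7), so train_on_batch must still be called manually.

    Hypothetical helper: only the leading "major.minor" part is parsed,
    and any patch component is ignored.
    """
    major, minor = (int(part) for part in medcat_version.split(".")[:2])
    return (major, minor) < (2, 7)

# Older medcat still needs the manual flush; 2.7+ can skip it.
print(needs_explicit_batch_flush("2.6.1"))  # True
print(needs_explicit_batch_flush("2.7.0"))  # False
```

Comparing `(major, minor)` tuples rather than raw strings keeps e.g. "2.10" correctly ordered after "2.7".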
I think I have a strong preference.
I would like to remove all calls to "train_on_batch" where possible.
This would remove having to call `cat._pipeline._components[-1].train_on_batch()` in the training loop, as in my original PR comment:
```python
cat.trainer.train_supervised_raw(train_projects, test_size=0, nepochs=1)
cat._pipeline._components[-1].train_on_batch()
cat._pipeline._components[-1].refresh_structure()
cat._pipeline._components[-1].create_embeddings()
```
I agree that for cases of small documents (e.g. COMETA), it's a minor pain to be embedding / training on just a few entities. But I don't think most training will happen this way.
Preferably I would change the dependency of the embedding linker to MedCAT 2.7 (or whichever version houses these fixes), remove all of these calls, and then rename train_on_batch to the protected _train_on_batch as originally intended.
If possible I'd like to have it so users don't have to bother with calling refresh_structure or create_embeddings either after training. But that's probably for a later date.
Yeah, 2.7 should have this fix in. So I'm fine with supporting only that and above. I think it's unlikely people will be pairing this with old medcat since there isn't really a massive legacy userbase with old versions installed anyway.
Hihi,
Here's the PR for the trainable embedding linker as foretold by our ancestors.
I'll list the main parts of it here, then the performances.
- `trainable_embedding_linker`, which inherits from the `embedding_linker`. The only functionality in the trainable version is that which is used only by it.
- The model code was moved out of `embedding_linker` into the `transformer_context_model` module. This has two classes within it:
  - `ContextModel`, which handles all embedding tasks (`embed_cuis`, ...).
  - `ModelForEmbeddingLinking`, which inherits from the transformers `nn.Module` and is where the model logic lives. It's largely a wrapper around the language model with a few choices of how to embed.
- `config` options: these explain themselves; all of them are to do with training and are effectively trade-offs between time / compute / performance.

Onto performances:
This is trained / tested with three datasets that all contain SNOMED CT labels: the SNOMED CT Entity Benchmark, Distemist, and COMETA. Since this requires training, I use an 80/20 train/test split. Given the implementation of cat.train at the time of writing, I built the training loop outside the medcat library.
Effectively it looks like this:
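A sketch reconstructed from the calls quoted earlier in the review thread (`train_supervised_raw`, `train_on_batch`, `refresh_structure`, `create_embeddings`); the `run_training` wrapper and its exact structure are my assumption, demonstrated here against a mock rather than a real `CAT` instance:

```python
from unittest.mock import MagicMock

def run_training(cat, train_projects, nepochs):
    """Hypothetical wrapper around the loop described above: one supervised
    pass per epoch, followed by the manual linker bookkeeping calls."""
    linker = cat._pipeline._components[-1]  # the trainable embedding linker
    for _ in range(nepochs):
        # test_size=0 keeps every project in the training split.
        cat.trainer.train_supervised_raw(train_projects, test_size=0, nepochs=1)
        linker.train_on_batch()     # flush the final partial batch
        linker.refresh_structure()  # rebuild the linker's internal structures
        linker.create_embeddings()  # re-embed names/CUIs with the updated model

# Demonstrate the call sequence with a stand-in for a real CAT instance.
cat = MagicMock()
run_training(cat, train_projects=[], nepochs=2)
print(cat.trainer.train_supervised_raw.call_count)  # 2
```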
Here's the baseline performance without training (which amounts to comparing against a normal `embedding_linker`):
Epoch: 0, Prec: 0.081820475543968, Rec: 0.3308877119673485, F1: 0.13119870535198244
Here's the best performance with a few hyperparams I've found to be optimal:
Epoch: 0, Prec: 0.11907464089601173, Rec: 0.5048605035481676, F1: 0.19269978201382124
These hyper-params I mention are, with a bit of commentary:
- `cat.config.components.linking.embedding_model_name = "abhinand/MedEmbed-small-v0.1"` (we default to "sentence-transformers/all-MiniLM-L6-v2" because it's kind of a standard, and it's small (6 layers). But MedEmbed in my experience is the best embedding I've found for medical stuff.)
- `context_window_size = 35` (with more context, performance should increase; if you use the `mention_mask`, the gains past 14 tokens are very marginal.)
- `cat._pipeline._components[-1].context_model.model.unfreeze_top_n_lm_layers(4)` (As you increase the number of layers, performance will increase, at the cost of storing more gradients and computational complexity. You will also transform the embedding space, requiring more data. I haven't gone beyond unfreezing four layers, because that's a redundant task at this point.)
- `lr=1e-4, weight_decay=0.01` (These are hardcoded, and it's a TODO that they shouldn't be. Oops. If you're only training the linear projection layer you can set the learning rate as high as 1e-3. When you start affecting transformer layers, it significantly impacts performance for the worse.)
- `cat._pipeline._components[-1].cnf_l.train_on_names = True` (You can train on CUIs or all potential names. A full CDB is about 3 mil names, or 600k CUIs. Names performs slightly better; CUIs performs quite a bit faster. It's a one-time cost for training a model, so something to consider.)

Here's the performance of some earlier iterations:
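Pulled together, applying these settings might look like the following; the attribute location I use for `context_window_size` is a guess on my part, and the demo uses a mock in place of a real `CAT` instance:

```python
from unittest.mock import MagicMock

def apply_tuned_hyperparams(cat):
    """Apply the settings discussed above. The path used for
    context_window_size is an assumption, not a confirmed API."""
    cat.config.components.linking.embedding_model_name = (
        "abhinand/MedEmbed-small-v0.1")     # default: all-MiniLM-L6-v2
    linker = cat._pipeline._components[-1]
    linker.cnf_l.context_window_size = 35   # assumed config location
    linker.cnf_l.train_on_names = True      # names: slower, slightly better
    linker.context_model.model.unfreeze_top_n_lm_layers(4)

# Exercise the function against a stand-in object.
cat = MagicMock()
apply_tuned_hyperparams(cat)
```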
Without mention_attention functionality:
Epoch: 0, Prec: 0.10870886017906067, Rec: 0.4608991494532199, F1: 0.17592386464826354
It performed best with smaller context windows (size 10).
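For intuition on what the `mention_mask` buys: instead of averaging every token embedding in the context window, only the mention's tokens contribute to the pooled vector. A toy illustration in plain Python (not the plugin's actual pooling code):

```python
def masked_mean(token_embeddings, mask):
    """Average only the token vectors where mask == 1."""
    selected = [vec for vec, keep in zip(token_embeddings, mask) if keep]
    dim = len(token_embeddings[0])
    return [sum(vec[i] for vec in selected) / len(selected) for i in range(dim)]

# Three toy 2-d token embeddings; the mention spans tokens 1 and 2.
tokens = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
mention_mask = [0, 1, 1]
print(masked_mean(tokens, mention_mask))  # [4.0, 3.0]
```

With the mask, the surrounding context tokens no longer dilute the mention representation, which is consistent with the mention-masked variant needing less context to do well.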
An earlier experiment, unfreezing layers (these runs were done without the COMETA dataset, so just Distemist and the SNOMED CT Benchmark):
ALL LAYERS FROZEN
Epoch: 0, Prec: 0.10861838458277627, Rec: 0.40456670917592763, F1: 0.17125743885040035
UNFREEZING TOP 1 LAYER
Epoch: 0, Prec: 0.11208035088291403, Rec: 0.4174662941819507, F1: 0.1767163258223325
UNFREEZING TOP 2 LAYERS
Epoch: 0, Prec: 0.11370157819225252, Rec: 0.4235394145511964, F1: 0.17927559702835402
UNFREEZING TOP 4 LAYERS
Epoch: 0, Prec: 0.11482145768791782, Rec: 0.4276691364022835, F1: 0.18103758547997326
TODO: