diff --git a/medcat-trainer/webapp/api/api/views.py b/medcat-trainer/webapp/api/api/views.py index 5bb33b0fd..ed287da1c 100644 --- a/medcat-trainer/webapp/api/api/views.py +++ b/medcat-trainer/webapp/api/api/views.py @@ -1,8 +1,9 @@ import logging import os from smtplib import SMTPException -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import Any +import shutil from background_task.models import Task, CompletedTask from django.contrib.auth.views import PasswordResetView @@ -579,11 +580,39 @@ def save_models(request): project = ProjectAnnotateEntities.objects.get(id=p_id) cat = get_medcat(project=project) - cat.cdb.save(project.concept_db.cdb_file.path) + if project.concept_db is not None: + # CDB / vocab based + cat.cdb.save(project.concept_db.cdb_file.path, overwrite=True) + else: + _overwrite_model_pack(cat, project.model_pack.path) return Response({'message': 'Models saved'}) +def _overwrite_model_pack(cat, model_path: str): + # NOTE: cannot overwrite, so working around + with TemporaryDirectory() as tmp_dir: + # making new folder name so that it's copied + # to the specific path rather than into the folder + temp_folder = os.path.join(tmp_dir, "model_copy") + shutil.move(model_path, temp_folder) + try: + cat.save_model_pack( + os.path.dirname(model_path), + pack_name=os.path.basename(model_path), + add_hash_to_pack_name=False) + except Exception as e: + logger.warning("Unable to save model pack. Restoring previous state") + if os.path.exists(model_path): + shutil.rmtree(model_path) # remove partial/corrupt output + # restore original + try: + shutil.move(temp_folder, model_path) + except Exception as restore_err: + logger.error("Failed to restore model pack:", exc_info=restore_err) + raise + + @api_view(http_method_names=['POST']) def get_create_entity(request): label = request.data['label']