Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions medcat-trainer/webapp/api/api/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
import os
from smtplib import SMTPException
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any
import shutil

from background_task.models import Task, CompletedTask
from django.contrib.auth.views import PasswordResetView
Expand Down Expand Up @@ -579,11 +580,39 @@ def save_models(request):
project = ProjectAnnotateEntities.objects.get(id=p_id)
cat = get_medcat(project=project)

cat.cdb.save(project.concept_db.cdb_file.path)
if project.concept_db is not None:
# CDB / vocab based
cat.cdb.save(project.concept_db.cdb_file.path, overwrite=True)
else:
_overwrite_model_pack(cat, project.model_pack.path)

return Response({'message': 'Models saved'})


def _overwrite_model_pack(cat, model_path: str):
# NOTE: cannot overwrite, so working around
with TemporaryDirectory() as tmp_dir:
# making new folder name so that it's copied
# to the specific path rather than into the folder
temp_folder = os.path.join(tmp_dir, "model_copy")
shutil.move(model_path, temp_folder)
try:
cat.save_model_pack(
os.path.dirname(model_path),
pack_name=os.path.basename(model_path),
add_hash_to_pack_name=False)
except Exception as e:
logger.warning("Unable to save model pack. Restoring previous state")
if os.path.exists(model_path):
shutil.rmtree(model_path) # remove partial/corrupt output
# restore original
try:
shutil.move(temp_folder, model_path)
except Exception as restore_err:
logger.error("Failed to restore model pack:", exc_info=restore_err)
raise


@api_view(http_method_names=['POST'])
def get_create_entity(request):
label = request.data['label']
Expand Down
Loading