Skip to content

Foil hole suggestions for a given grid square#296

Open
d-j-hatton wants to merge 3 commits into
mainfrom
feature/foilhole-suggestions
Open

Foil hole suggestions for a given grid square#296
d-j-hatton wants to merge 3 commits into
mainfrom
feature/foilhole-suggestions

Conversation

@d-j-hatton
Copy link
Copy Markdown
Contributor

Adds add an endpoint to make suggestions about holes to collect on a specified grid square accounting for foil hole score and distribution across different foil hole types as determined from a latent space representation (provided externally to this package). A larger number of holes are selected per latent space cluster when compared to the grid square code to account for the fact that it is unlikely every cluster will appear on a given grid square (clustering is across the whole grid).

Addresses #295

…fied grid square accounting for foil hole score and distribution across different foil hole types as determined from a latent space representation (provided externally to this package)
@d-j-hatton d-j-hatton added the enhancement Minor improvements to existing functionality label Jun 3, 2026
Copy link
Copy Markdown
Collaborator

@vredchenko vredchenko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this - the structure closely mirrors get_suggested_square_collections, which is the right approach. A few divergences from that reference function look like they would cause the endpoint to fail on most real inputs, and there is no test coverage to catch it. Details inline; the central one is that scores here holds (FoilHole, CurrentQualityPrediction) Row tuples that are sorted and indexed as if they were FoilHole objects, whereas the square endpoint extracts p[0] and sorts on x[1].value first.

Worth adding while here:

  • A unit test for this endpoint (and ideally backfilling one for the square version) - it is a pure ranking function over DB rows, so it should be cheap to test and would pin down the row handling and the per-cluster cap.
  • Module-level named constants for the two magic numbers (see inline), e.g. near the top of api_server.py:
HOLE_SELECTION_FRACTION = 2  # consider the top 1/N of scored holes
MAX_HOLES_PER_CLUSTER = 4

The inline suggestion uses these names. The same rationale (4 here vs 2 for grid squares) could later be applied to the square endpoint for consistency, but that is out of scope for this PR.

Marking as a comment rather than a blocker.

Comment thread src/smartem_backend/api_server.py Outdated
async def get_suggested_hole_collections(
gridsquare_uuid: str, prediction_model_name: str, latent_rep_model_name: str, db: AsyncSession = DB_DEPENDENCY
):
gridsquare = (await db.execute(select(GridSquare).where(GridSquare.uuid == gridsquare_uuid))).one()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.one() returns a Row, not a GridSquare, so gridsquare.grid_uuid on the next line raises AttributeError. Everywhere else in api_server.py uses .scalar_one() / .scalars() for single-entity fetches - this is the only bare .one(). It runs before the loop, so it would fail on every request regardless of hole count.

Suggested change
gridsquare = (await db.execute(select(GridSquare).where(GridSquare.uuid == gridsquare_uuid))).one()
gridsquare = (await db.execute(select(GridSquare).where(GridSquare.uuid == gridsquare_uuid))).scalar_one()

Comment thread src/smartem_backend/api_server.py Outdated
Comment on lines +2644 to +2652
scores.sort(reverse=True)
cluster_counts = dict.fromkeys(set(cluster_indices.values()), 0)
suggested = []
for i in range(len(scores) // 2):
hole = scores[i]
if cluster_counts[cluster_indices[hole.uuid]] < 4:
suggested.append(hole)
cluster_counts[cluster_indices[hole.uuid]] += 1
return suggested
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scores is a list of (FoilHole, CurrentQualityPrediction) Row tuples, which causes three problems here:

  1. scores.sort(reverse=True) compares rows tuple-wise, hitting the FoilHole objects first. FoilHole is a SQLModel with no __lt__, so this raises TypeError as soon as there are two or more holes. It also means the prediction value is never read - the square endpoint sorts with key=lambda x: x[1].value, so as written the ranking-by-score is lost entirely, which contradicts the PR description.
  2. hole = scores[i] is a Row, so hole.uuid is an AttributeError (row keys are entity-level, not column-level).
  3. suggested.append(hole) appends the Row against response_model=list[FoilHole], which will not serialise.

Adopting the square endpoint's extract-then-sort idiom fixes all three. This also guards the cluster_indices lookup (otherwise a hole with no cluster-index row raises KeyError) and uses the named constants from the summary:

Suggested change
scores.sort(reverse=True)
cluster_counts = dict.fromkeys(set(cluster_indices.values()), 0)
suggested = []
for i in range(len(scores) // 2):
hole = scores[i]
if cluster_counts[cluster_indices[hole.uuid]] < 4:
suggested.append(hole)
cluster_counts[cluster_indices[hole.uuid]] += 1
return suggested
score_ordered_holes = [p[0] for p in sorted(scores, key=lambda x: x[1].value, reverse=True)]
cluster_counts = dict.fromkeys(set(cluster_indices.values()), 0)
suggested = []
for hole in score_ordered_holes[: len(score_ordered_holes) // HOLE_SELECTION_FRACTION]:
cluster = cluster_indices.get(hole.uuid)
if cluster is not None and cluster_counts[cluster] < MAX_HOLES_PER_CLUSTER:
suggested.append(hole)
cluster_counts[cluster] += 1
return suggested

Note the suggestion references HOLE_SELECTION_FRACTION and MAX_HOLES_PER_CLUSTER; add those at module level (see summary) before applying, or drop in the literals 2 and 4 if you would rather not.

select(QualityPredictionModelParameter)
.where(QualityPredictionModelParameter.grid_uuid == grid_uuid)
.where(QualityPredictionModelParameter.prediction_model_name == latent_rep_model_name)
.where(QualityPredictionModelParameter.group == "cluster_indices")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This filters the cluster parameters by grid_uuid and latent-rep model only, reusing the same "cluster_indices" group as the grid-square endpoint. Could you confirm the externally-provided parameters here are keyed by foil-hole uuid and will not collide with the grid-square-level cluster indices stored under the same group and grid? If both levels share the group, a distinct group name or an extra discriminator may be needed.

@vredchenko
Copy link
Copy Markdown
Collaborator

@d-j-hatton as you may have gathered the above review is me unleashing Claude Code on the PR, it's known to make mistakes. I shall attend in person later, focusing on clearing f/e feature backlog at present

@vredchenko
Copy link
Copy Markdown
Collaborator

Also, as we merge it we need to be mindful that we've changed the API source of truth - to that end we'll need a version bump resulting in a new release, which should kick off automation to update f/e client and devtools repo docs - leave it with me to verify it all works as intended

@d-j-hatton
Copy link
Copy Markdown
Contributor Author

The sorting issue was a genuine bug. I also changed the .one() behaviour to .scalar_one() to align better with the rest of the code but haven't checked whether .one() works or not with the async session (I'm fairly sure that it would work normally)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement Minor improvements to existing functionality

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants