Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
ENVIRONMENT=development
DEBUG=TRUE
# Disables the lookup of language maps on startup (speeds up dev boot)
NO_LM=TRUE
SERVICE_NAME=oclapi2

# -----------------------------------------------------------------------------
Expand Down
127 changes: 79 additions & 48 deletions core/common/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import time
import urllib

import requests
from cid.locals import get_cid
from django.conf import settings
from django.db.models import Case, When, IntegerField
from elasticsearch_dsl import FacetedSearch, Q
from pydash import compact, get, has, set_
from sentence_transformers import CrossEncoder

from core.common import ERRBIT_LOGGER
from core.common.constants import ES_REQUEST_TIMEOUT
from core.common.utils import is_url_encoded_string

Expand Down Expand Up @@ -336,42 +338,46 @@ def __get_response(self, exact_count=True, load_fields=False):
return self._dsl_search, None, total


class VectorEmbed:
    """Produce an embedding for a piece of text.

    When ``settings.EMBEDDING_SERVICE_URL`` is set the embedding is requested
    from ``<url>/embeddings``; if that request fails for any reason the error
    is logged and a locally loaded SentenceTransformer model is used instead.
    """

    # Cache of loaded SentenceTransformer models keyed by model name. It lives
    # on the class, so every instance shares it and each model is constructed
    # at most once per process.
    _LOCAL_MODELS = {}

    def __init__(self, model_name=None):
        """Use *model_name*, or ``settings.LM_MODEL_NAME`` when it is falsy."""
        self.model_name = model_name or settings.LM_MODEL_NAME

    def embed(self, txt):
        """Return the embedding for *txt* from the service if configured, else locally."""
        if settings.EMBEDDING_SERVICE_URL:
            return self._get_embedding_from_service(txt)
        return self._get_embedding_locally(txt)

    def _get_embedding_from_service(self, txt):
        """Request the embedding over HTTP, falling back to the local model on any error.

        Returns ``data[0].embedding`` from the service's JSON response.
        """
        try:
            # Tolerate a trailing slash in the configured base URL so the
            # request path never becomes '//embeddings'.
            base_url = settings.EMBEDDING_SERVICE_URL.rstrip('/')
            api_key = settings.INFINITY_API_KEY
            # Only send credentials when a key is configured: "Bearer " with an
            # empty token is not a well-formed Authorization header.
            headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
            response = requests.post(
                f'{base_url}/embeddings',
                headers=headers,
                json={'model': self.model_name, 'input': str(txt)},
                timeout=10,
            )
            response.raise_for_status()
            return response.json()['data'][0]['embedding']
        except Exception as ex:  # pylint: disable=broad-except
            # Deliberate best-effort: report the failure, then degrade to the local model.
            ERRBIT_LOGGER.log(ex)
            return self._get_embedding_locally(txt)

    def _get_embedding_locally(self, txt):
        """Encode *txt* with the cached local model and return the vector as a list."""
        model = self._LOCAL_MODELS.get(self.model_name)
        if model is None:
            from sentence_transformers import SentenceTransformer
            model = self._LOCAL_MODELS[self.model_name] = SentenceTransformer(self.model_name)
        return model.encode(str(txt)).tolist()


class Reranker:
ENCODERS = [
# Best and Fastest overall lightweight medical reranker
# Size: ~110M
# Speed: similar to MiniLM CrossEncoder
# Training: includes clinical, medical, question-answering datasets
# Output: positive similarity scores (not raw logits!)
# 0.6B params
# https://huggingface.co/BAAI/bge-reranker-v2-m3
"BAAI/bge-reranker-v2-m3",

# Model: jinhybr/OA-MedBERT-cross-encoder or similar
# Size: ~110M
# Domain: PubMed abstracts, biomedical QA
# Type: binary classifier (logits)
# Not a Hugging Face model -- ???
# "jinhybr/OA-MedBERT-cross-encoder",

# Model: microsoft/BioLinkBERT-base
# Type: CrossEncoder
# Size: ~120M
# Domain: UMLS, PubMed, MeSH, SNOMED (closest to OCL)
# Not a Hugging Face model -- doesn't work with sentence_transformers
# "microsoft/BioLinkBERT-base",

# 22.7M params
# https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2
# doesn't work with logits, so not between 0-1
"cross-encoder/ms-marco-MiniLM-L-6-v2",
]
SCORE_KEY = 'search_rerank_score'
MISSING_SCORE = -1000000.0

def __init__(self, model_name=None):
self.model_name = model_name
self.encoder = self._get_encoder(self.model_name)
self.model_name = model_name or self.default_model
self.encoder = None

def rerank( # pylint: disable=too-many-arguments
self, hits, txt, name_key='name', source_attr=None, should_convert_source_to_dict=True,
Expand All @@ -393,18 +399,53 @@ def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_sourc
return scores_full

docs = [get(self._get_source(hit, source_attr, should_convert_source_to_dict), name_key) for hit in hits]
valid = []
valid_docs = []
for i, d in enumerate(docs):
if isinstance(d, str) and d.strip():
valid.append((i, d.strip()))
if not valid:
valid_docs.append((i, d.strip()))
if not valid_docs:
return scores_full
scores = self.encoder.predict([(txt, d) for _, d in valid])
for (i, _), s in zip(valid, scores):

scores = self._get_rerank_scores(txt, valid_docs)
for (i, _), s in zip(valid_docs, scores):
scores_full[i] = float(s)

return scores_full

def _get_rerank_scores(self, txt, docs):
    """Score *docs* against the query *txt*.

    *docs* is a list of ``(hit_index, text)`` tuples. The remote service is
    used when ``settings.EMBEDDING_SERVICE_URL`` is set, the local encoder
    otherwise.
    """
    use_service = bool(settings.EMBEDDING_SERVICE_URL)
    scorer = self._get_rerank_scores_from_service if use_service else self._get_rerank_scores_locally
    return scorer(txt, docs)

def _get_rerank_scores_from_service(self, txt, docs):
    """Ask the remote ``/rerank`` endpoint to score *docs* against *txt*.

    *docs* is a list of ``(hit_index, text)`` tuples; only the texts are sent.
    Returns one score per entry of *docs*, in the same order. Documents the
    service returned no result for get ``MISSING_SCORE``. Any failure (HTTP
    error, malformed payload, out-of-range index) is logged and the local
    encoder is used instead.
    """
    try:
        # Tolerate a trailing slash in the configured base URL.
        base_url = settings.EMBEDDING_SERVICE_URL.rstrip('/')
        api_key = settings.INFINITY_API_KEY
        # Only send credentials when a key is configured: "Bearer " with an
        # empty token is not a well-formed Authorization header.
        headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
        response = requests.post(
            f'{base_url}/rerank',
            headers=headers,
            json={
                'model': self.model_name or self.default_model,
                'query': txt,
                'documents': [text for _, text in docs],
            },
            timeout=60,
        )
        response.raise_for_status()
        results = response.json()['results']
        # Each result names the position ('index') of the document it scores.
        # Place scores by that position instead of relying on the order or the
        # length of the response, so a partial response cannot shift scores
        # onto the wrong documents.
        scores = [self.MISSING_SCORE] * len(docs)
        for result in results:
            position = result['index']
            if not 0 <= position < len(docs):
                raise ValueError(f'rerank result index {position} is out of range')
            scores[position] = result['relevance_score']
        return scores
    except Exception as ex:  # pylint: disable=broad-except
        # Deliberate best-effort: report the failure, then degrade to local scoring.
        ERRBIT_LOGGER.log(ex)
        return self._get_rerank_scores_locally(txt, docs)

def _get_rerank_scores_locally(self, txt, docs):
    """Score *docs* against *txt* with the in-process encoder.

    The encoder is created on first use and kept on the instance. If loading
    or prediction raises, the error is logged and every document receives
    ``MISSING_SCORE``.
    """
    try:
        self.encoder = self.encoder or self._get_encoder()
        query_doc_pairs = [(txt, text) for _, text in docs]
        return self.encoder.predict(query_doc_pairs)
    except Exception as error:  # pylint: disable=broad-except
        ERRBIT_LOGGER.log(error)
    return [self.MISSING_SCORE] * len(docs)

def _assign_score(self, hits, scores, score_key, order_results):
score_key = score_key or self.SCORE_KEY
key_to_set = score_key
Expand All @@ -420,18 +461,8 @@ def _assign_score(self, hits, scores, score_key, order_results):
def _order(hits, key_to_order):
return sorted(hits, key=lambda hit: get(hit, key_to_order), reverse=True)

def _get_encoder(self, model_name):
if model_name and model_name != self.default_model:
return self._load_encoder(model_name)
return self._load_default_encoder()

@staticmethod
def _load_encoder(model_name):
return CrossEncoder(model_name, device="cpu", max_length=128)

@staticmethod
def _load_default_encoder():
return settings.ENCODER
def _get_encoder(self):
    """Construct a CrossEncoder for ``self.model_name`` on CPU with ``max_length=128``."""
    encoder_options = {'device': 'cpu', 'max_length': 128}
    return CrossEncoder(self.model_name, **encoder_options)

@staticmethod
def _get_source(data, source_attr, should_convert_source_to_dict):
Expand Down
174 changes: 174 additions & 0 deletions core/common/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,3 +1621,177 @@ def test_multi_token_input_expands_each_known_token(self, mock_load, mock_resolv

terms = LexicalVariantDictionary.get_variant_terms('childhood leukaemia colour')
self.assertEqual(set(terms), {'leukemia', 'color'})


class VectorEmbedTest(OCLTestCase):
    """VectorEmbed picks between the embedding service and the local model."""

    def setUp(self):
        self.embedder = None

    @override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
    @patch('requests.post')
    def test_embed_uses_service_when_url_configured(self, mock_post):
        """With a service URL configured, the embedding comes from POST <url>/embeddings."""
        from core.common.search import VectorEmbed
        payload = {'data': [{'embedding': [0.1, 0.2, 0.3]}]}
        mock_post.return_value = Mock(status_code=200, json=Mock(return_value=payload))

        result = VectorEmbed().embed('malaria')

        self.assertEqual(result, [0.1, 0.2, 0.3])
        mock_post.assert_called_once()
        positional, keyword = mock_post.call_args
        self.assertIn('/embeddings', positional[0])
        self.assertEqual(keyword['headers']['Authorization'], 'Bearer test-key')
        self.assertEqual(keyword['json']['input'], 'malaria')

    @override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='')
    @patch('requests.post')
    def test_embed_falls_back_to_local_on_service_error(self, mock_post):
        """A failing service call is answered by the local model instead."""
        from core.common.search import VectorEmbed
        mock_post.side_effect = Exception('connection refused')
        embedder = VectorEmbed()

        with patch.object(embedder, '_get_embedding_locally', return_value=[0.4, 0.5]) as mock_local:
            result = embedder.embed('diabetes')

        mock_local.assert_called_once_with('diabetes')
        self.assertEqual(result, [0.4, 0.5])

    @override_settings(EMBEDDING_SERVICE_URL='')
    def test_embed_uses_local_when_no_service_url(self):
        """Without a service URL the local model is used directly."""
        from core.common.search import VectorEmbed
        embedder = VectorEmbed()

        with patch.object(embedder, '_get_embedding_locally', return_value=[0.1, 0.2]) as mock_local:
            result = embedder.embed('hypertension')

        mock_local.assert_called_once_with('hypertension')
        self.assertEqual(result, [0.1, 0.2])

    @override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='')
    @patch('requests.post')
    def test_embed_uses_custom_model_name(self, mock_post):
        """An explicit model name is forwarded to the service."""
        from core.common.search import VectorEmbed
        payload = {'data': [{'embedding': [0.9]}]}
        mock_post.return_value = Mock(status_code=200, json=Mock(return_value=payload))

        VectorEmbed(model_name='custom/model').embed('test')

        _, keyword = mock_post.call_args
        self.assertEqual(keyword['json']['model'], 'custom/model')


class RerankerTest(OCLTestCase):
    """Reranker scoring through the remote service, the local encoder, and their fallbacks."""

    SERVICE_SETTINGS = {'EMBEDDING_SERVICE_URL': 'http://embed-service:8008', 'INFINITY_API_KEY': 'test-key'}

    @staticmethod
    def _make_hit(name, use_search_meta=False):
        """Build a minimal hit dict, optionally carrying an empty ``search_meta``."""
        if use_search_meta:
            return {'name': name, 'search_meta': {}}
        return {'name': name}

    @staticmethod
    def _service_response(*scores):
        """Build a mocked 200 response whose results score documents 0..n-1 with *scores*."""
        results = [{'index': index, 'relevance_score': score} for index, score in enumerate(scores)]
        return Mock(status_code=200, json=Mock(return_value={'results': results}))

    @override_settings(**SERVICE_SETTINGS)
    @patch('requests.post')
    def test_rerank_uses_service_when_url_configured(self, mock_post):
        """With a service URL configured, scores come from POST <url>/rerank."""
        from core.common.search import Reranker
        mock_post.return_value = self._service_response(0.9, 0.3)
        hits = [self._make_hit('malaria fever'), self._make_hit('diabetes')]

        result = Reranker().rerank(hits, 'malaria', order_results=False)

        self.assertEqual(result[0]['search_rerank_score'], 0.9)
        self.assertEqual(result[1]['search_rerank_score'], 0.3)
        mock_post.assert_called_once()
        self.assertIn('/rerank', mock_post.call_args[0][0])

    @override_settings(**SERVICE_SETTINGS)
    @patch('requests.post')
    def test_rerank_falls_back_to_local_on_service_error(self, mock_post):
        """A failing service call is answered by the local scorer instead."""
        from core.common.search import Reranker
        mock_post.side_effect = Exception('timeout')
        reranker = Reranker()
        hits = [self._make_hit('malaria')]

        with patch.object(reranker, '_get_rerank_scores_locally', return_value=[0.7]) as mock_local:
            reranker.rerank(hits, 'malaria', order_results=False)

        mock_local.assert_called_once()

    @override_settings(**SERVICE_SETTINGS)
    @patch('requests.post')
    def test_rerank_returns_missing_score_when_both_fail(self, mock_post):
        """When the service and the local encoder both fail, hits get MISSING_SCORE."""
        from core.common.search import Reranker
        mock_post.side_effect = Exception('timeout')
        reranker = Reranker()
        hits = [self._make_hit('malaria')]

        # Break encoder construction rather than stubbing the local scorer, so
        # the scorer's own error handling is what produces MISSING_SCORE.
        with patch.object(reranker, '_get_encoder', side_effect=Exception('model unavailable')) as mock_encoder:
            result = reranker.rerank(hits, 'malaria', order_results=False)

        mock_encoder.assert_called_once()
        self.assertEqual(result[0]['search_rerank_score'], Reranker.MISSING_SCORE)

    @override_settings(EMBEDDING_SERVICE_URL='')
    def test_rerank_uses_local_when_no_service_url(self):
        """Without a service URL the local scorer is used directly."""
        from core.common.search import Reranker
        reranker = Reranker()
        hits = [self._make_hit('malaria')]

        with patch.object(reranker, '_get_rerank_scores_locally', return_value=[0.8]) as mock_local:
            reranker.rerank(hits, 'malaria', order_results=False)

        mock_local.assert_called_once()

    def test_rerank_returns_empty_on_no_hits(self):
        """Reranking an empty hit list yields an empty list."""
        from core.common.search import Reranker
        result = Reranker().rerank([], 'malaria')
        self.assertEqual(result, [])

    def test_rerank_returns_missing_score_on_blank_query(self):
        """A whitespace-only query leaves every hit at MISSING_SCORE."""
        from core.common.search import Reranker
        hits = [self._make_hit('malaria')]
        result = Reranker().rerank(hits, ' ', order_results=False)
        self.assertEqual(result[0]['search_rerank_score'], Reranker.MISSING_SCORE)

    # Pin the service off so this test never depends on the ambient
    # EMBEDDING_SERVICE_URL (and never attempts a real HTTP call).
    @override_settings(EMBEDDING_SERVICE_URL='')
    def test_rerank_skips_hits_with_missing_name(self):
        """Hits without a usable name are not scored and keep MISSING_SCORE."""
        from core.common.search import Reranker
        reranker = Reranker()
        hits = [self._make_hit(''), self._make_hit('malaria')]

        with patch.object(reranker, '_get_rerank_scores_locally', return_value=[0.9]) as mock_local:
            result = reranker.rerank(hits, 'malaria', order_results=False)

        # only one valid doc passed to scorer
        self.assertEqual(len(mock_local.call_args[0][1]), 1)
        self.assertEqual(result[0]['search_rerank_score'], Reranker.MISSING_SCORE)
        self.assertEqual(result[1]['search_rerank_score'], 0.9)

    @override_settings(**SERVICE_SETTINGS)
    @patch('requests.post')
    def test_rerank_orders_results_by_score_descending(self, mock_post):
        """By default hits are returned best score first."""
        from core.common.search import Reranker
        mock_post.return_value = self._service_response(0.2, 0.8)
        hits = [self._make_hit('diabetes'), self._make_hit('malaria fever')]

        result = Reranker().rerank(hits, 'malaria')

        self.assertEqual(result[0]['name'], 'malaria fever')
        self.assertEqual(result[1]['name'], 'diabetes')

    @override_settings(**SERVICE_SETTINGS)
    @patch('requests.post')
    def test_rerank_assigns_score_to_search_meta_when_present(self, mock_post):
        """Hits that carry ``search_meta`` receive their scores inside it."""
        from core.common.search import Reranker
        mock_post.return_value = self._service_response(0.5)
        hits = [self._make_hit('malaria', use_search_meta=True)]

        result = Reranker().rerank(hits, 'malaria', order_results=False)

        self.assertEqual(result[0]['search_meta']['search_rerank_score'], 0.5)
        self.assertEqual(result[0]['search_meta']['search_normalized_score'], 50.0)

    @override_settings(**SERVICE_SETTINGS)
    @patch('requests.post')
    def test_rerank_uses_custom_model(self, mock_post):
        """An explicit model name is forwarded to the service."""
        from core.common.search import Reranker
        mock_post.return_value = self._service_response(0.6)

        Reranker(model_name='custom/reranker').rerank([self._make_hit('malaria')], 'malaria')

        self.assertEqual(mock_post.call_args[1]['json']['model'], 'custom/reranker')
12 changes: 0 additions & 12 deletions core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,15 +918,3 @@ def format_url_for_search(url):

def clean_term(term):
return term.lower().replace(' ', '').replace('-', '').replace('_', '')


def get_embeddings(txt):
from core.toggles.models import Toggle
if not Toggle.get('SEMANTIC_SEARCH_TOGGLE') or settings.ENV == 'ci':
return None

model = settings.LM
if not model:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(settings.LM_MODEL_NAME)
return model.encode(str(txt))
Loading