File size: 6,326 Bytes
ab4488b |
|
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import base64
import functools
import itertools
import json
from typing import List, Optional, Dict, cast

from azure.core.paging import ItemPaged, PageIterator, ReturnType

from ._generated.models import SearchRequest, SearchDocumentsResult, QueryAnswerResult
from ._api_versions import DEFAULT_VERSION
def convert_search_result(result):
    """Flatten a search result object into a plain dictionary.

    The mapping held in ``result.additional_properties`` is used as the base
    dictionary, and the search metadata of the result is stored next to it
    under the ``@search.*`` keys.

    :param result: Object exposing ``additional_properties``, ``score``,
        ``reranker_score``, ``highlights`` and ``captions`` attributes.
    :return: Dictionary containing the additional properties plus the
        ``@search.score``, ``@search.reranker_score``, ``@search.highlights``
        and ``@search.captions`` entries.
    :rtype: dict
    """
    ret = result.additional_properties
    if ret is None:
        # Without this fallback the first item assignment below would raise
        # TypeError for a result that carries no additional properties.
        ret = {}
    # NOTE: an existing ``additional_properties`` dict is updated in place and
    # returned as-is (no copy is made), matching the previous behaviour.
    ret["@search.score"] = result.score
    ret["@search.reranker_score"] = result.reranker_score
    ret["@search.highlights"] = result.highlights
    ret["@search.captions"] = result.captions
    return ret
def pack_continuation_token(response, api_version=DEFAULT_VERSION):
    """Encode the data needed to request the next page into a token.

    :param response: Object exposing ``next_page_parameters`` and
        ``next_link`` attributes.
    :param api_version: Value recorded under the ``apiVersion`` key of the
        token payload.
    :return: Base64-encoded bytes of the UTF-8 JSON payload, or None when
        ``response.next_page_parameters`` is None.
    """
    next_params = response.next_page_parameters
    if next_params is None:
        # Nothing to continue from, so no token is produced.
        return None

    payload = json.dumps(
        {
            "apiVersion": api_version,
            "nextLink": response.next_link,
            "nextPageParameters": next_params.serialize(),
        }
    )
    return base64.b64encode(payload.encode("utf-8"))
def unpack_continuation_token(token):
    """Decode a continuation token back into its next-page components.

    :param token: Base64-encoded JSON payload containing the ``nextLink`` and
        ``nextPageParameters`` keys.
    :return: Tuple of the stored next link and a ``SearchRequest``
        deserialized from the stored next-page parameters.
    :rtype: tuple
    """
    decoded = base64.b64decode(token)
    fields = json.loads(decoded)
    link = fields["nextLink"]
    request = SearchRequest.deserialize(fields["nextPageParameters"])
    return link, request
class SearchItemPaged(ItemPaged[ReturnType]):
    """Item pager that also exposes metadata of the search response.

    One page iterator is created lazily and then shared by item iteration and
    by the ``get_*`` accessors.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Page iterator shared by __next__ and the get_* accessors; created
        # on first use by _first_iterator_instance().
        self._first_page_iterator_instance: Optional[SearchPageIterator] = None

    def __next__(self) -> ReturnType:
        """Return the next item, chaining the pages of the shared iterator."""
        pages = self._page_iterator
        if pages is None:
            pages = itertools.chain.from_iterable(self._first_iterator_instance())
            self._page_iterator = pages
        return next(pages)

    def _first_iterator_instance(self) -> "SearchPageIterator":
        """Return the shared page iterator, creating it via ``by_page()`` once."""
        cached = self._first_page_iterator_instance
        if cached is None:
            cached = cast(SearchPageIterator, self.by_page())
            self._first_page_iterator_instance = cached
        return cached

    def get_facets(self) -> Optional[Dict]:
        """Return any facet results if faceting was requested.

        :return: facet results
        :rtype: dict or None
        """
        page_iterator = self._first_iterator_instance()
        return cast(Dict, page_iterator.get_facets())

    def get_coverage(self) -> float:
        """Return the coverage percentage, if `minimum_coverage` was
        specified for the query.

        :return: coverage percentage
        :rtype: float
        """
        page_iterator = self._first_iterator_instance()
        return cast(float, page_iterator.get_coverage())

    def get_count(self) -> int:
        """Return the count of results if `include_total_count` was
        set for the query.

        :return: count of results
        :rtype: int
        """
        page_iterator = self._first_iterator_instance()
        return cast(int, page_iterator.get_count())

    def get_answers(self) -> Optional[List[QueryAnswerResult]]:
        """Return semantic answers. Only included if the semantic ranker is used
        and answers are requested in the search query via the query_answer parameter.

        :return: answers
        :rtype: list[~azure.search.documents.models.QueryAnswerResult] or None
        """
        page_iterator = self._first_iterator_instance()
        return cast(List[QueryAnswerResult], page_iterator.get_answers())
# The pylint error silenced below seems spurious, as the inner wrapper does, in
# fact, become a method of the class when it is applied.
def _ensure_response(f):
    """Decorate a method so a response is available before it runs.

    When ``self._current_page`` is None, the wrapper calls
    ``self._get_next(self.continuation_token)``, stores the result in
    ``self._response`` and updates ``self.continuation_token`` and
    ``self._current_page`` from ``self._extract_data``. It then calls the
    wrapped method. If a page is already present, nothing is fetched.

    :param f: Method to wrap; it receives ``self`` followed by the original
        positional and keyword arguments.
    :return: Wrapper that keeps the metadata of ``f`` (name, docstring,
        ``__wrapped__``).
    """
    # pylint:disable=protected-access
    @functools.wraps(f)  # keep __name__/__doc__ of the decorated method
    def wrapper(self, *args, **kw):
        if self._current_page is None:
            self._response = self._get_next(self.continuation_token)
            self.continuation_token, self._current_page = self._extract_data(self._response)
        return f(self, *args, **kw)

    return wrapper
class SearchPageIterator(PageIterator):
    """Page iterator that issues search requests through ``client``.

    The most recent response is kept in ``self._response`` so that the
    ``get_*`` accessors can read its metadata.
    """

    def __init__(self, client, initial_query, kwargs, continuation_token=None) -> None:
        super().__init__(
            get_next=self._get_next_cb,
            extract_data=self._extract_data_cb,
            continuation_token=continuation_token,
        )
        self._client = client
        self._initial_query = initial_query
        self._facets = None
        # ``kwargs`` is stored by reference: popping "api_version" below also
        # removes it from ``self._kwargs`` (and from the caller's dict).
        self._kwargs = kwargs
        self._api_version = kwargs.pop("api_version", DEFAULT_VERSION)

    def _get_next_cb(self, continuation_token):
        """Send the search request for the page identified by the token.

        Without a token the initial query's request is sent; otherwise the
        request stored in the token is used.
        """
        if continuation_token is None:
            request = self._initial_query.request
        else:
            _next_link, request = unpack_continuation_token(continuation_token)
        return self._client.documents.search_post(search_request=request, **self._kwargs)

    def _extract_data_cb(self, response):
        """Return the next continuation token and the converted results."""
        token = pack_continuation_token(response, api_version=self._api_version)
        documents = list(map(convert_search_result, response.results))
        return token, documents

    def _reset_token_and_get_response(self) -> SearchDocumentsResult:
        """Clear the continuation token and return the stored response."""
        # NOTE(review): presumably the token is cleared so that a later
        # iteration starts again from the first page; if called after
        # iteration has begun this may end paging early -- confirm against
        # PageIterator.
        self.continuation_token = None
        return cast(SearchDocumentsResult, self._response)

    @_ensure_response
    def get_facets(self) -> Optional[Dict]:
        """Return the facets of the response as plain dicts, or None."""
        facets = self._reset_token_and_get_response().facets
        if self._facets is None and facets is not None:
            assert facets.items() is not None  # Hint for mypy
            # Converted once and cached on the iterator.
            self._facets = {
                name: [entry.as_dict() for entry in entries] for name, entries in facets.items()
            }
        return self._facets

    @_ensure_response
    def get_coverage(self) -> float:
        """Return the ``coverage`` value of the response."""
        response = self._reset_token_and_get_response()
        return cast(float, response.coverage)

    @_ensure_response
    def get_count(self) -> int:
        """Return the ``count`` value of the response."""
        response = self._reset_token_and_get_response()
        return cast(int, response.count)

    @_ensure_response
    def get_answers(self) -> Optional[List[QueryAnswerResult]]:
        """Return the ``answers`` value of the response."""
        response = self._reset_token_and_get_response()
        return response.answers
|