from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from typing import List, Optional, Any, Union
import json
class SimpleKVStoreRetriever(BaseRetriever):
    """Retrieve documents whose key in a key-value store starts with the query.

    Documents are written to ``store`` by :meth:`set_up_store` as UTF-8 bytes in
    the layout ``"Page Content\\n" + page_content + "\\nMetadata" + <metadata JSON>``.
    Retrieval lists the keys matching the query prefix, keeps the first ``k`` of
    them, fetches their values and decodes them back into ``Document`` objects.
    """

    # Byte store that holds the serialized documents.
    store: BigtableByteStore
    # Documents (or raw strings, treated as page content) to load into the store.
    documents: List[Union[Document, str]]
    # Maximum number of matching keys to fetch per query.
    k: int

    @staticmethod
    def _encode_document(doc: Document) -> bytes:
        """Serialize ``doc`` into the byte layout read by ``_decode_document``."""
        value = (
            "Page Content\n"
            + doc.page_content
            + "\nMetadata"
            + json.dumps(doc.metadata)
        )
        return value.encode("utf-8")

    @staticmethod
    def _decode_document(raw: bytes) -> Document:
        """Rebuild a ``Document`` from bytes produced by ``_encode_document``.

        Raises:
            ValueError: If ``raw`` does not follow the expected layout (this
                includes ``json.JSONDecodeError`` for an unparsable metadata
                section, which is a ``ValueError`` subclass).
        """
        header = "Page Content\n"
        marker = "\nMetadata"
        text = raw.decode("utf-8")
        # json.dumps never emits a raw newline, so the *last* marker is always
        # the real separator, even when the page content itself happens to
        # contain "\nMetadata" or "Content\n".
        body, found, metadata_json = text.rpartition(marker)
        if not found or not body.startswith(header):
            raise ValueError(
                "Stored value is not in the format written by set_up_store"
            )
        return Document(
            page_content=body[len(header):],
            metadata=json.loads(metadata_json),
        )

    def set_up_store(self) -> None:
        """Write every entry of ``self.documents`` into ``self.store``.

        Plain strings are wrapped in a ``Document``. A document without an id
        is given its position in ``self.documents`` (as a string) as id; the id
        is used as the store key.
        """
        kv_pairs_to_set = []
        for i, doc in enumerate(self.documents):
            if isinstance(doc, str):
                doc = Document(page_content=doc)
            if not doc.id:
                # Assigned on the Document object itself, so a caller-supplied
                # Document without an id gets one as a side effect.
                doc.id = str(i)
            kv_pairs_to_set.append((doc.id, self._encode_document(doc)))
        self.store.mset(kv_pairs_to_set)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    ) -> List[Document]:
        """Asynchronously return up to ``k`` documents whose key starts with ``query``.

        Args:
            query: Key prefix to match.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The decoded documents; keys whose value is missing or empty are
            skipped.
        """
        keys = [key async for key in self.store.ayield_keys(prefix=query)][: self.k]
        # amget is a coroutine that resolves to a plain sequence, so it is
        # awaited once and then iterated with a regular loop (``async for``
        # over the awaited result fails because a list is not an async iterable).
        values = await self.store.amget(keys)
        return [self._decode_document(raw) for raw in values if raw]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    ) -> List[Document]:
        """Return up to ``k`` documents whose key starts with ``query``.

        Args:
            query: Key prefix to match.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The decoded documents; keys whose value is missing or empty are
            skipped.
        """
        keys = [key for key in self.store.yield_keys(prefix=query)][: self.k]
        return [self._decode_document(raw) for raw in self.store.mget(keys) if raw]