File size: 5,784 Bytes
import os
import logging
from pinecone import Pinecone, ServerlessSpec
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_community.retrievers import BM25Retriever
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Module-level setup. Runs once at import time, before any RagPipeline is built.
# load_dotenv() is called first so that values from a local .env file are
# available when RagPipeline.__init__ reads PINECONE_API_KEY via os.getenv.
load_dotenv()
# Configure root logging at INFO so the pipeline's progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every RagPipeline method.
logger = logging.getLogger(__name__)
# Default Pinecone index name used when RagPipeline is built without one.
INDEX_NAME = "finance-rag"
class RagPipeline:
    """Hybrid (BM25 + dense) document retrieval on top of a Pinecone index.

    The pipeline loads PDFs, splits them into chunks, uploads the chunks to
    Pinecone for dense retrieval, and keeps the most recently split chunks in
    memory (``cached_docs``) so a BM25 retriever can be built over them.

    Attributes:
        pc: Pinecone client created from ``PINECONE_API_KEY``.
        index_name: Name of the Pinecone index used for dense retrieval.
        dimension: Vector size used if the index has to be created.
        bm25_retriever: BM25 retriever, or ``None`` until ``create_bm25`` succeeds.
        cached_docs: Chunks produced by the last successful ``split_docs`` call.
        embeddings: HuggingFace embedding model used by the vector store.
    """

    def __init__(
        self,
        index_name=INDEX_NAME,
        embedding_model="BAAI/bge-base-en-v1.5",
        dimension=768,
    ):
        """Connect to Pinecone, make sure the index exists, load the embedder.

        Args:
            index_name: Pinecone index to use (created if missing).
            embedding_model: HuggingFace model name used for embeddings.
            dimension: Vector size for a newly created index. It must equal
                the output size of ``embedding_model``; the default of 768
                matches the default model. It is ignored when the index
                already exists.

        Raises:
            ValueError: If ``PINECONE_API_KEY`` is not set in the environment.
        """
        api_key = os.getenv("PINECONE_API_KEY")
        if not api_key:
            raise ValueError("PINECONE_API_KEY not found in environment variables.")

        self.pc = Pinecone(api_key=api_key)
        self.index_name = index_name
        # Must be set before _ensure_index(), which reads it.
        self.dimension = dimension
        self.bm25_retriever = None
        self.cached_docs = []

        self._ensure_index()

        logger.info("Loading embedding model...")
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        logger.info("Embedding model loaded successfully.")

    def _ensure_index(self) -> None:
        """Create the Pinecone index (cosine, serverless aws/us-east-1) if absent."""
        existing_indexes = self.pc.list_indexes().names()
        if self.index_name in existing_indexes:
            logger.info("Pinecone index '%s' already exists.", self.index_name)
            return

        logger.info("Creating Pinecone index: %s", self.index_name)
        self.pc.create_index(
            name=self.index_name,
            dimension=self.dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
        logger.info("Pinecone index created successfully.")

    def vector_store(self):
        """Return a ``PineconeVectorStore`` bound to this index and embedder."""
        return PineconeVectorStore(
            index=self.pc.Index(self.index_name),
            embedding=self.embeddings,
        )

    def load_docs(self, pdf_path: str):
        """Load a PDF and return the documents produced by ``PyPDFLoader``.

        Args:
            pdf_path: Path of the PDF file to load.

        Raises:
            Exception: Whatever the loader raised; it is logged and re-raised.
        """
        logger.info("Loading PDF: %s", pdf_path)
        try:
            documents = PyPDFLoader(pdf_path).load()
        except Exception:
            logger.exception("Error loading PDF.")
            # Bare raise keeps the original traceback intact.
            raise
        logger.info("Loaded %d pages.", len(documents))
        return documents

    def split_docs(self, docs, chunk_size=1000, chunk_overlap=250):
        """Split documents into overlapping chunks and cache them for BM25.

        Args:
            docs: Documents to split (e.g. the result of ``load_docs``).
            chunk_size: Maximum chunk length passed to the splitter.
            chunk_overlap: Overlap between consecutive chunks.

        Returns:
            The list of chunks, or ``None`` if splitting failed (the error is
            logged and ``cached_docs`` is left unchanged).
        """
        try:
            logger.info("Splitting documents into chunks...")
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            chunks = splitter.split_documents(docs)
        except Exception:
            logger.exception("Error splitting documents.")
            return None

        # NOTE: this replaces (does not extend) the cache, so BM25 only ever
        # covers the most recently split document set.
        self.cached_docs = chunks
        logger.info("Generated %d chunks.", len(chunks))
        return chunks

    def add_docs(self, split_docs) -> None:
        """Upload chunks to Pinecone; failures are logged, not raised.

        Args:
            split_docs: Chunks to upload. ``None`` or an empty list (for
                example a failed ``split_docs`` result) is skipped with a
                warning instead of being sent to the vector store.
        """
        if not split_docs:
            logger.warning("No chunks to upload; skipping Pinecone upload.")
            return

        try:
            logger.info("Uploading chunks to Pinecone...")
            self.vector_store().add_documents(split_docs)
            logger.info("Documents uploaded to Pinecone successfully.")
        except Exception:
            logger.exception("Error uploading documents to Pinecone.")

    def delete_all_docs(self) -> None:
        """Delete every vector in the index and drop the local BM25 state.

        Local state is cleared only when the remote delete succeeds, so the
        cache never claims to be empty while Pinecone still holds vectors.
        Failures are logged, not raised.
        """
        try:
            logger.info("Deleting ALL documents from Pinecone index...")
            self.pc.Index(self.index_name).delete(delete_all=True)
            logger.info("All documents deleted successfully.")
            self.bm25_retriever = None
            self.cached_docs = []
        except Exception:
            logger.exception("Error deleting all documents.")

    def create_bm25(self, split_docs=None, k=4) -> None:
        """Build the BM25 retriever; failures are logged, not raised.

        Args:
            split_docs: Chunks to index. When ``None``, ``cached_docs`` is used.
            k: Number of documents the BM25 retriever returns per query.
        """
        docs = split_docs if split_docs is not None else self.cached_docs
        if not docs:
            # Nothing to index; keep whatever retriever is already installed.
            logger.warning("No chunks available; BM25 retriever not created.")
            return

        try:
            logger.info("Creating BM25 retriever...")
            retriever = BM25Retriever.from_documents(docs)
            retriever.k = k
        except Exception:
            logger.exception("Error creating BM25 retriever.")
            return

        # Publish only a fully configured retriever.
        self.bm25_retriever = retriever
        logger.info("BM25 retriever ready.")

    def dense_retriever(self, k=4):
        """Return a dense retriever over the vector store that yields ``k`` hits."""
        return self.vector_store().as_retriever(search_kwargs={"k": k})

    def hybrid_retrieve(self, query, dense_k=4, top_k=6):
        """Retrieve documents with BM25 and dense search, de-duplicated.

        BM25 hits are listed first, followed by dense hits. Documents whose
        stripped ``page_content`` was already seen are dropped, and the result
        is truncated to ``top_k``.

        Args:
            query: Query text.
            dense_k: Number of hits requested from the dense retriever.
            top_k: Maximum number of documents returned.

        Returns:
            A list of documents. It is empty when there is no BM25 retriever
            and no cached chunks to build one from (the caller is expected to
            fall back to answering without context), or when an unexpected
            error occurs.
        """
        try:
            # Decide on the fallback before querying Pinecone: the result is
            # empty in this case regardless of what dense retrieval returns,
            # so the embedding and network round trip would be wasted.
            if self.bm25_retriever is None and not self.cached_docs:
                logger.warning("No uploaded docs found. Using direct LLM fallback.")
                return []

            try:
                dense_docs = self.dense_retriever(k=dense_k).invoke(query)
            except Exception:
                # Best effort: continue with BM25 only, but record the cause.
                logger.warning("Dense retrieval unavailable.", exc_info=True)
                dense_docs = []

            if self.bm25_retriever is None:
                logger.info("Rebuilding BM25 retriever.")
                self.create_bm25()

            # create_bm25 may have failed, leaving the retriever unset.
            bm25_docs = []
            if self.bm25_retriever is not None:
                bm25_docs = self.bm25_retriever.invoke(query)

            seen = set()
            unique = []
            for doc in bm25_docs + dense_docs:
                text = doc.page_content.strip()
                if text not in seen:
                    seen.add(text)
                    unique.append(doc)
            return unique[:top_k]
        except Exception:
            logger.exception("Error during hybrid retrieval.")
            return []