File size: 5,784 Bytes
ca67025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import logging
from pinecone import Pinecone, ServerlessSpec
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_community.retrievers import BM25Retriever
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Default Pinecone index used by RagPipeline when no name is given.
INDEX_NAME = "finance-rag"

# Load variables such as PINECONE_API_KEY from a local .env file into the
# process environment before anything reads them.
load_dotenv()

# Module-wide logging at INFO level, with a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RagPipeline:
    """Hybrid (BM25 + dense) retrieval pipeline backed by a Pinecone index.

    Typical flow: ``load_docs`` -> ``split_docs`` -> ``add_docs`` ->
    ``create_bm25`` -> ``hybrid_retrieve``. Dense vectors are uploaded to
    Pinecone, while the BM25 retriever is an in-memory object built from the
    chunks of the most recent ``split_docs`` call (``cached_docs``), so it is
    lost when the process restarts.
    """

    def __init__(
        self,
        index_name=INDEX_NAME,
        embedding_model="BAAI/bge-base-en-v1.5",
        dimension=768,
    ):
        """Connect to Pinecone, ensure the index exists and load the embedder.

        Args:
            index_name: Name of the Pinecone index (created if missing).
            embedding_model: HuggingFace model name used for dense embeddings.
            dimension: Vector size used when the index has to be created. It
                must equal the output size of ``embedding_model`` (768 for the
                default ``BAAI/bge-base-en-v1.5``).

        Raises:
            ValueError: If ``PINECONE_API_KEY`` is not set in the environment.
        """
        api_key = os.getenv("PINECONE_API_KEY")

        if not api_key:
            raise ValueError("PINECONE_API_KEY not found in environment variables.")

        self.pc = Pinecone(api_key=api_key)
        self.index_name = index_name
        self.dimension = dimension
        self.bm25_retriever = None
        # Chunks from the last split_docs() call; BM25 is (re)built from these.
        self.cached_docs = []
        # k of the BM25 retriever, remembered so lazy rebuilds keep the same k.
        self._bm25_k = 4

        self._ensure_index()

        logger.info("Loading embedding model...")
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        logger.info("Embedding model loaded successfully.")

    def _ensure_index(self):
        """Create the serverless cosine-metric index if it does not exist yet.

        An existing index is reused as-is; its dimension is not compared with
        ``self.dimension``.
        """
        existing_indexes = self.pc.list_indexes().names()

        if self.index_name in existing_indexes:
            logger.info("Pinecone index '%s' already exists.", self.index_name)
            return

        logger.info("Creating Pinecone index: %s", self.index_name)
        self.pc.create_index(
            name=self.index_name,
            dimension=self.dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
        logger.info("Pinecone index created successfully.")

    def vector_store(self):
        """Return a LangChain vector store bound to this pipeline's index."""
        return PineconeVectorStore(
            index=self.pc.Index(self.index_name),
            embedding=self.embeddings,
        )

    def load_docs(self, pdf_path: str):
        """Load a PDF with ``PyPDFLoader`` and return its documents.

        Raises:
            Exception: Any loader error is logged and re-raised unchanged.
        """
        logger.info("Loading PDF: %s", pdf_path)
        try:
            documents = PyPDFLoader(pdf_path).load()
        except Exception:
            logger.exception("Error loading PDF.")
            # Bare raise keeps the original traceback intact.
            raise

        logger.info("Loaded %d pages.", len(documents))
        return documents

    def split_docs(self, docs, chunk_size=1000, chunk_overlap=250):
        """Split documents into overlapping chunks and cache them for BM25.

        The new chunks replace ``cached_docs`` and the current BM25 retriever
        is dropped, so the next retrieval rebuilds BM25 from these chunks
        instead of serving results from a previously split document.

        Returns:
            The list of chunks, or ``None`` if splitting failed (the error is
            logged, not raised).
        """
        logger.info("Splitting documents into chunks...")
        try:
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            split_documents = splitter.split_documents(docs)
        except Exception:
            logger.exception("Error splitting documents.")
            return None

        self.cached_docs = split_documents
        # The existing BM25 index was built from the old chunks; invalidate it.
        self.bm25_retriever = None

        logger.info("Generated %d chunks.", len(split_documents))
        return split_documents

    def add_docs(self, split_docs):
        """Embed the given chunks and upload them to Pinecone.

        Returns:
            ``True`` on success; ``False`` when there was nothing to upload or
            the upload failed (the error is logged, not raised).
        """
        if not split_docs:
            # split_docs() returns None on failure; don't send that to Pinecone.
            logger.warning("No chunks to upload to Pinecone.")
            return False

        logger.info("Uploading chunks to Pinecone...")
        try:
            self.vector_store().add_documents(split_docs)
        except Exception:
            logger.exception("Error uploading documents to Pinecone.")
            return False

        logger.info("Documents uploaded to Pinecone successfully.")
        return True

    def delete_all_docs(self):
        """Delete every vector in the index, then reset the local BM25 state.

        Errors are logged, not raised. If the remote delete fails, the local
        BM25 retriever and ``cached_docs`` are left untouched.
        """
        try:
            logger.info("Deleting ALL documents from Pinecone index...")

            index = self.pc.Index(self.index_name)
            index.delete(delete_all=True)

            logger.info("All documents deleted successfully.")

            self.bm25_retriever = None
            self.cached_docs = []

        except Exception:
            logger.exception("Error deleting all documents.")

    def create_bm25(self, split_docs=None, k=4):
        """Build the in-memory BM25 retriever.

        Args:
            split_docs: Chunks to index; defaults to ``self.cached_docs``.
            k: Number of documents BM25 returns per query. It is remembered so
                that a lazy rebuild in ``hybrid_retrieve`` uses the same value.

        With no chunks available the retriever is cleared and a warning is
        logged. Other errors are logged, not raised.
        """
        self._bm25_k = k
        docs = split_docs if split_docs is not None else self.cached_docs

        if not docs:
            logger.warning("No documents available to build a BM25 retriever.")
            self.bm25_retriever = None
            return

        try:
            logger.info("Creating BM25 retriever...")
            retriever = BM25Retriever.from_documents(docs)
            retriever.k = k
            self.bm25_retriever = retriever
            logger.info("BM25 retriever ready.")
        except Exception:
            logger.exception("Error creating BM25 retriever.")

    def dense_retriever(self, k=4):
        """Return a Pinecone-backed retriever yielding the ``k`` nearest chunks."""
        vectorstore = self.vector_store()

        return vectorstore.as_retriever(
            search_kwargs={"k": k}
        )

    def hybrid_retrieve(self, query, dense_k=4, top_k=6):
        """Retrieve chunks for ``query`` by merging BM25 and dense results.

        BM25 hits come first, then dense (Pinecone) hits. Documents whose
        stripped ``page_content`` was already seen are dropped, and at most
        ``top_k`` documents are returned. When BM25 is unavailable (for
        example, no cached chunks after a restart), the dense results are
        still returned rather than discarded. An empty list means nothing was
        retrieved, in which case the caller can answer without context.
        Errors are logged and yield ``[]``.
        """
        try:
            dense_docs = []
            try:
                dense_docs = self.dense_retriever(k=dense_k).invoke(query)
            except Exception:
                logger.warning("Dense retrieval unavailable.", exc_info=True)

            if self.bm25_retriever is None and self.cached_docs:
                logger.info("Rebuilding BM25 retriever.")
                self.create_bm25(k=self._bm25_k)

            bm25_docs = []
            if self.bm25_retriever is not None:
                bm25_docs = self.bm25_retriever.invoke(query)
            elif not dense_docs:
                logger.warning("No uploaded docs found. Using direct LLM fallback.")
                return []
            else:
                logger.info("BM25 unavailable; using dense results only.")

            seen = set()
            unique = []

            for doc in bm25_docs + dense_docs:
                text = doc.page_content.strip()

                if text not in seen:
                    seen.add(text)
                    unique.append(doc)

            return unique[:top_k]

        except Exception:
            logger.exception("Error during hybrid retrieval.")
            return []