hoshingakag committed (verified)
Commit 37b6839 · 1 Parent(s): de866f1

Upload 9 files
app.py ADDED
@@ -0,0 +1,114 @@
+ from src.config import setup_logging, Config
+ from src.embeddings_model import GEmbeddings
+ from src.text_generation_model import GLLM
+ from src.pinecone_index import PineconeIndex
+ from src.llamaindex_backend import GLlamaIndex
+
+ import gradio as gr
+ import google.generativeai as genai
+ from llama_index.core import Settings
+
+ from typing import List
+ import time
+
+ import dotenv
+ dotenv.load_dotenv(".env")
+
+ logger = setup_logging()
+
+ # Google Generative AI
+ genai.configure(api_key=Config.GAI_API_KEY)
+
+ # Llama-Index LLM
+ embed_model = GEmbeddings(model_name=Config.EMB_MODEL_NAME)
+ llm = GLLM(model_name=Config.TEXT_MODEL_NAME, system_instruction=None)
+
+ Settings.embed_model = embed_model
+ Settings.llm = llm
+
+ index = PineconeIndex(api_key=Config.PINECONE_API_KEY, index_name=Config.PC_INDEX_NAME, index_namespace=Config.PC_INDEX_NAMESPACE)
+ backend = GLlamaIndex(logger, embed_model, llm, index, Config.SIMILARITY_THRESHOLD)
+
+ # Gradio
+ chat_history = []
+
+ def clear_chat() -> None:
+     global chat_history
+     chat_history = []
+     return None
+
+ def get_chat_history(chat_history: List[str]) -> str:
+     ind = 0
+     formatted_chat_history = ""
+     for message in chat_history:
+         formatted_chat_history += f"User: \n{message}\n" if ind % 2 == 0 else f"Bot: \n{message}\n"
+         ind += 1
+     return formatted_chat_history
+
+ def generate_text(prompt: str, backend: GLlamaIndex):
+     global chat_history
+
+     logger.info("Generating Message...")
+     logger.info(f"User Message:\n{prompt}\n")
+
+     result = backend.generate_text(prompt, chat_history)
+     chat_history.append(prompt)
+     chat_history.append(result)
+
+     logger.info(f"Replied Message:\n{result}\n")
+     return result
+
+ if __name__ == "__main__":
+     try:
+         with gr.Blocks(css=".input textarea {font-size: 16px !important}") as app:
+             chatbot = gr.Chatbot(
+                 bubble_full_width=False,
+                 container=True,
+                 show_share_button=False,
+                 avatar_images=[None, './asset/akag-g-only.png']
+             )
+             msg = gr.Textbox(
+                 show_label=False,
+                 label="Type your message...",
+                 placeholder="Hi Gerard, can you introduce yourself?",
+                 container=False,
+                 elem_classes="input"
+             )
+             with gr.Row():
+                 clear = gr.Button("Clear", scale=1)
+                 send = gr.Button(
+                     value="",
+                     variant="primary",
+                     icon="./asset/send-message.png",
+                     scale=1
+                 )
+
+             def user(user_message, history):
+                 return "", history + [[user_message, None]]
+
+             def bot(history):
+                 bot_message = generate_text(history[-1][0], backend)
+                 history[-1][1] = ""
+                 for character in bot_message:
+                     history[-1][1] += character
+                     time.sleep(0.01)
+                     yield history
+
+             msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+                 bot, chatbot, chatbot
+             )
+             send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+                 bot, chatbot, chatbot
+             )
+             clear.click(clear_chat, None, chatbot, queue=False)
+
+             gr.HTML("""
+             <p><center><i>Disclaimer: This RAG app is for demonstration only. Hallucinations may occur.</i></center></p>
+             <p><center>Hosted on 🤗 Spaces | Built with Google Gemini & 🦙 LlamaIndex | Last updated 2025</center></p>
+             """)
+
+         app.queue()
+         app.launch()
+
+     except Exception as e:
+         logger.exception(e)
asset/akag-g-only.png ADDED
asset/send-message.png ADDED
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ google-generativeai==0.8.3
+ llama-index==0.12.12
+ llama-index-vector-stores-pinecone==0.4.2
+ transformers==4.48.0
+ pinecone-client==5.0.1
+ wandb==0.19.2
+ # transformers==4.30.2
+ # llama-index==0.8.48
+ # wandb==0.15.12
src/config.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import logging
+
+ def setup_logging():
+     logging.basicConfig(
+         format='%(asctime)s %(message)s',
+         datefmt='%Y-%m-%d %I:%M:%S %p',
+         level=logging.INFO
+     )
+     return logging.getLogger('backend')
+
+ class Config:
+     # External services
+     GAI_API_KEY = os.environ['GAI_API_KEY']
+     PINECONE_API_KEY = os.environ['PINECONE_API_KEY']
+     WANDB_API_KEY = os.environ['WANDB_API_KEY']
+     WANDB_PROJECT = os.environ['WANDB_PROJECT']
+
+     # Model settings
+     TEXT_MODEL_NAME = os.getenv('TEXT_MODEL_NAME', 'gemini-1.5-flash')
+     EMB_MODEL_NAME = os.getenv('EMB_MODEL_NAME', 'models/text-embedding-004')
+     PC_INDEX_NAME = os.getenv('PC_INDEX_NAME', 'main-index')
+     PC_INDEX_NAMESPACE = os.getenv('PC_INDEX_NAMESPACE', 'main')
+     CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 32768))
+     NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 4098))
+     TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 2048))
+     TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 200))
+     TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
+     TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
+     SIMILARITY_THRESHOLD = float(os.getenv('SIMILARITY_THRESHOLD', 0.7))
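
Note: src/config.py reads GAI_API_KEY, PINECONE_API_KEY, WANDB_API_KEY and WANDB_PROJECT with os.environ[...], so they must exist before the module is imported (app.py loads them from a local .env file); the remaining settings fall back to the defaults shown above. A minimal setup sketch for running outside the Space, with placeholder values only:

    # Hypothetical local setup; the values below are placeholders, not real keys.
    import os

    os.environ["GAI_API_KEY"] = "<google-generative-ai-api-key>"
    os.environ["PINECONE_API_KEY"] = "<pinecone-api-key>"
    os.environ["WANDB_API_KEY"] = "<wandb-api-key>"
    os.environ["WANDB_PROJECT"] = "<wandb-project-name>"

    # Optional overrides; defaults are defined in src/config.py.
    os.environ.setdefault("TEXT_MODEL_NAME", "gemini-1.5-flash")
    os.environ.setdefault("PC_INDEX_NAME", "main-index")

    from src.config import Config  # import only after the variables are set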
src/embeddings_model.py ADDED
@@ -0,0 +1,35 @@
+ from typing import List, Dict, Any
+ import google.generativeai as genai
+ from llama_index.core.embeddings import BaseEmbedding
+
+ class GEmbeddings(BaseEmbedding):
+     def __init__(
+         self,
+         model_name: str = 'models/text-embedding-004',
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(**kwargs)
+         self._model_name = model_name
+
+     def gai_embed_content(self, text: str) -> Dict[str, Any]:
+         return genai.embed_content(model=self._model_name, content=text)
+
+     def _get_query_embedding(self, query: str) -> List[float]:
+         embeddings = self.gai_embed_content(query)
+         return embeddings['embedding']
+
+     def _get_text_embedding(self, text: str) -> List[float]:
+         embeddings = self.gai_embed_content(text)
+         return embeddings['embedding']
+
+     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+         embeddings = [
+             self.gai_embed_content(text)['embedding'] for text in texts
+         ]
+         return embeddings
+
+     async def _aget_query_embedding(self, query: str) -> List[float]:
+         return self._get_query_embedding(query)
+
+     async def _aget_text_embedding(self, text: str) -> List[float]:
+         return self._get_text_embedding(text)
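
A quick smoke-test sketch for GEmbeddings, assuming genai has been configured and the environment variables above are set; get_query_embedding is the public wrapper LlamaIndex's BaseEmbedding provides around _get_query_embedding, and the query string is illustrative:

    import google.generativeai as genai
    from src.config import Config
    from src.embeddings_model import GEmbeddings

    genai.configure(api_key=Config.GAI_API_KEY)
    emb = GEmbeddings(model_name=Config.EMB_MODEL_NAME)

    # The public BaseEmbedding API dispatches to the private methods defined above.
    vector = emb.get_query_embedding("What does Gerard work on?")
    print(len(vector))  # text-embedding-004 embeddings are 768-dimensional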
src/llamaindex_backend.py ADDED
@@ -0,0 +1,145 @@
+ from src.embeddings_model import GEmbeddings
+ from src.text_generation_model import GLLM
+ from src.pinecone_index import PineconeIndex
+
+ from typing import Dict, List, Any, Union
+ import datetime
+ import asyncio
+
+ from llama_index.core.evaluation import SemanticSimilarityEvaluator
+ from llama_index.core.base.embeddings.base import SimilarityMode
+
+ prompt_template = """
+ <system>
+ You are in a role play of Gerard Lee. Gerard is a data enthusiast and humble about his success.
+ Reply as faithfully as possible and in no more than 5 complete sentences unless <user query> requests to elaborate in detail. Use contents from <context> only without prior knowledge except referring to <chat history> for seamless conversation.
+ </system>
+
+ <chat history>
+ {context_history}
+ </chat history>
+
+ <context>
+ {context_from_index}
+ </context>
+
+ <user query>
+ {user_query}
+ </user query>
+ """
+
+ class GLlamaIndex():
+     def __init__(
+         self,
+         logger,
+         emb_model: GEmbeddings,
+         text_model: GLLM,
+         index: PineconeIndex,
+         similarity_threshold: float
+     ) -> None:
+         self.logger = logger
+         self.emb_model = emb_model
+         self.llm = text_model
+         self.index = index
+         self.evaluator = self._set_evaluator(similarity_threshold)
+         self.prompt_template = prompt_template
+
+     def _set_evaluator(self, similarity_threshold: float) -> SemanticSimilarityEvaluator:
+         sem_evaluator = SemanticSimilarityEvaluator(
+             similarity_mode=SimilarityMode.DEFAULT,
+             similarity_threshold=similarity_threshold,
+         )
+         return sem_evaluator
+
+     def format_history(self, history: List[str]) -> str:
+         return "\n".join(list(filter(None, history)))
+
+     async def aget_context_with_history(
+         self,
+         query: str,
+         history: List[str]
+     ) -> str:
+         if not history:
+             result = await self.index.aretrieve_context(query)
+             return result["result"]
+
+         extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
+         results = await self.index.aretrieve_context_multi(
+             [query, extended_query]
+         )
+         print(results)
+         eval_results = await self.aevaluate_context_multi(
+             [query, extended_query],
+             [r["result"] for r in results]
+         )
+         print(eval_results)
+         return results[0]["result"] if eval_results[0].score > eval_results[1].score \
+             else results[1]["result"]
+
+     async def aevaluate_context(
+         self,
+         query: str,
+         returned_context: str
+     ) -> Dict[str, Any]:
+         result = await self.evaluator.aevaluate(
+             response=returned_context,
+             reference=query,
+         )
+         return result
+
+     async def aevaluate_context_multi(
+         self,
+         query_list: List[str],
+         returned_context_list: List[str]
+     ) -> List[Dict]:
+         result = await asyncio.gather(*(self.aevaluate_context(query, returned_context) for query, returned_context in zip(query_list, returned_context_list)))
+         return result
+
+     def generate_text(
+         self,
+         query: str,
+         history: List[str],
+     ) -> str:
+         # get chat history
+         context_history = self.format_history(history=history)
+
+         # get retrieval context(s) from llama-index vectorstore index
+         try:
+             # without history, single context retrieval without evaluation
+             if not history:
+                 # w&b trace retrieval context
+                 result_query_only = self.index.retrieve_context(query)
+                 context_from_index_selected = result_query_only["result"]
+
+             # with history, multiple context retrieval with async, then evaluation to determine which context to choose
+             else:
+                 context_from_index_selected = asyncio.run(self.aget_context_with_history(query=query, history=history))
+
+         except Exception as e:
+             self.logger.error(f"Exception {e} occurred when retrieving context\n")
+
+             llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
+             result = "Something went wrong. Please try again later."
+             return result
+
+         self.logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")
+
+         # generate text with prompt template to roleplay myself
+         prompt_with_context = self.prompt_template.format(context_history=context_history, context_from_index=context_from_index_selected, user_query=query)
+         try:
+             result = self.llm.gai_generate_content(
+                 prompt=prompt_with_context,
+                 temperature=0.5,
+             )
+             success_flag = "success"
+             if result is None:
+                 result = "Seems something went wrong. Please try again later."
+                 self.logger.error(f"Result with 'None' received\n")
+                 success_flag = "fail"
+
+         except Exception as e:
+             result = "Seems something went wrong. Please try again later."
+             self.logger.error(f"Exception {e} occurred\n")
+             success_flag = "fail"
+
+         return result
src/pinecone_index.py ADDED
@@ -0,0 +1,42 @@
+ from typing import List, Dict, Union
+ import datetime
+ import asyncio
+ from pinecone import Pinecone
+ from llama_index.core import VectorStoreIndex, StorageContext
+ from llama_index.vector_stores.pinecone import PineconeVectorStore
+
+ class PineconeIndex:
+     def __init__(self, api_key: str, index_name: str, index_namespace: str):
+         self._index_name = index_name
+         self._index_namespace = index_namespace
+         self._pc = Pinecone(api_key=api_key)
+         self.pc_index = self._set_index(index_name, index_namespace)
+
+     def _set_index(self, index_name: str, index_namespace: str) -> VectorStoreIndex:
+         vector_store = PineconeVectorStore(
+             pinecone_index=self._pc.Index(index_name),
+             add_sparse_vector=True,
+             namespace=index_namespace
+         )
+         storage_context = StorageContext.from_defaults(vector_store=vector_store)
+         pc_index = VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)
+         return pc_index
+
+     def retrieve_context(self, query: str) -> Dict[str, Union[str, int]]:
+         start_time = round(datetime.datetime.now().timestamp() * 1000)
+         response = self.pc_index.as_query_engine(similarity_top_k=3).query(query)
+         end_time = round(datetime.datetime.now().timestamp() * 1000)
+         return {"result": response.response, "start": start_time, "end": end_time}
+
+     async def aretrieve_context(self, query: str) -> Dict[str, Union[str, int]]:
+         start_time = round(datetime.datetime.now().timestamp() * 1000)
+         response = await self.pc_index.as_query_engine(
+             similarity_top_k=3,
+             use_async=True
+         ).aquery(query)
+         end_time = round(datetime.datetime.now().timestamp() * 1000)
+         return {"result": response.response, "start": start_time, "end": end_time}
+
+     async def aretrieve_context_multi(self, query_list: List[str]) -> List[Dict]:
+         result = await asyncio.gather(*(self.aretrieve_context(query) for query in query_list))
+         return result
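
A standalone retrieval check, sketched under the assumption that the Pinecone index named in Config already exists and is populated. as_query_engine relies on the globally configured models, so Settings is wired up first, exactly as app.py does; the query string is illustrative:

    import google.generativeai as genai
    from llama_index.core import Settings
    from src.config import Config
    from src.embeddings_model import GEmbeddings
    from src.text_generation_model import GLLM
    from src.pinecone_index import PineconeIndex

    genai.configure(api_key=Config.GAI_API_KEY)
    Settings.embed_model = GEmbeddings(model_name=Config.EMB_MODEL_NAME)
    Settings.llm = GLLM(model_name=Config.TEXT_MODEL_NAME, system_instruction=None)

    index = PineconeIndex(
        api_key=Config.PINECONE_API_KEY,
        index_name=Config.PC_INDEX_NAME,
        index_namespace=Config.PC_INDEX_NAMESPACE,
    )
    out = index.retrieve_context("Can you introduce yourself?")
    print(out["result"], f'({out["end"] - out["start"]} ms)')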
src/text_generation_model.py ADDED
@@ -0,0 +1,61 @@
+ from typing import Any
+ import google.generativeai as genai
+ from google.generativeai.types import HarmCategory, HarmBlockThreshold
+ from llama_index.core.llms import (
+     CustomLLM,
+     CompletionResponse,
+     CompletionResponseGen,
+     LLMMetadata,
+ )
+ from llama_index.core.llms.callbacks import llm_completion_callback
+
+ class GLLM(CustomLLM):
+     def __init__(
+         self,
+         context_window: int = 32768,
+         num_output: int = 4098,
+         model_name: str = "gemini-1.5-flash",
+         system_instruction: str = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(**kwargs)
+         self._context_window = context_window
+         self._num_output = num_output
+         self._model_name = model_name
+         self._model = genai.GenerativeModel(model_name, system_instruction=system_instruction)
+
+     def gai_generate_content(self, prompt: str, temperature: float = 0.5) -> str:
+         return self._model.generate_content(
+             prompt,
+             generation_config=genai.GenerationConfig(
+                 temperature=temperature,
+             ),
+             safety_settings={
+                 HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                 HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+             }
+         ).text
+
+     @property
+     def metadata(self) -> LLMMetadata:
+         """Get LLM metadata."""
+         return LLMMetadata(
+             context_window=self._context_window,
+             num_output=self._num_output,
+             model_name=self._model_name,
+         )
+
+     @llm_completion_callback()
+     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+         text = self.gai_generate_content(prompt)
+         return CompletionResponse(text=text)
+
+     @llm_completion_callback()
+     def stream_complete(
+         self, prompt: str, **kwargs: Any
+     ) -> CompletionResponseGen:
+         text = self.gai_generate_content(prompt)
+         response = ""
+         for token in text:
+             response += token
+             yield CompletionResponse(text=response, delta=token)
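
A minimal sketch of exercising GLLM on its own, assuming genai is configured with a valid key; complete() returns a LlamaIndex CompletionResponse whose .text attribute holds the generated string, and the prompt is illustrative:

    import google.generativeai as genai
    from src.config import Config
    from src.text_generation_model import GLLM

    genai.configure(api_key=Config.GAI_API_KEY)
    llm = GLLM(model_name=Config.TEXT_MODEL_NAME, system_instruction=None)

    response = llm.complete("Introduce yourself in one sentence.")
    print(response.text)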