hoshingakag committed
Upload 9 files
- app.py +114 -0
- asset/akag-g-only.png +0 -0
- asset/send-message.png +0 -0
- requirements.txt +9 -0
- src/config.py +30 -0
- src/embeddings_model.py +35 -0
- src/llamaindex_backend.py +145 -0
- src/pinecone_index.py +42 -0
- src/text_generation_model.py +61 -0
app.py
ADDED
@@ -0,0 +1,114 @@
from src.config import setup_logging, Config
from src.embeddings_model import GEmbeddings
from src.text_generation_model import GLLM
from src.pinecone_index import PineconeIndex
from src.llamaindex_backend import GLlamaIndex

import gradio as gr
import google.generativeai as genai
from llama_index.core import Settings

from typing import List
import time

import dotenv
dotenv.load_dotenv(".env")

logger = setup_logging()

# Google Generative AI
genai.configure(api_key=Config.GAI_API_KEY)

# Llama-Index LLM
embed_model = GEmbeddings(model_name=Config.EMB_MODEL_NAME)
llm = GLLM(model_name=Config.TEXT_MODEL_NAME, system_instruction=None)

Settings.embed_model = embed_model
Settings.llm = llm

index = PineconeIndex(api_key=Config.PINECONE_API_KEY, index_name=Config.PC_INDEX_NAME, index_namespace=Config.PC_INDEX_NAMESPACE)
backend = GLlamaIndex(logger, embed_model, llm, index, Config.SIMILARITY_THRESHOLD)

# Gradio
chat_history = []

def clear_chat() -> None:
    global chat_history
    chat_history = []
    return None

def get_chat_history(chat_history: List[str]) -> str:
    ind = 0
    formatted_chat_history = ""
    for message in chat_history:
        formatted_chat_history += f"User: \n{message}\n" if ind % 2 == 0 else f"Bot: \n{message}\n"
        ind += 1
    return formatted_chat_history

def generate_text(prompt: str, backend: GLlamaIndex):
    global chat_history

    logger.info("Generating Message...")
    logger.info(f"User Message:\n{prompt}\n")

    result = backend.generate_text(prompt, chat_history)
    chat_history.append(prompt)
    chat_history.append(result)

    logger.info(f"Replied Message:\n{result}\n")
    return result

if __name__ == "__main__":
    try:
        with gr.Blocks(css=".input textarea {font-size: 16px !important}") as app:
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                container=True,
                show_share_button=False,
                avatar_images=[None, './asset/akag-g-only.png']
            )
            msg = gr.Textbox(
                show_label=False,
                label="Type your message...",
                placeholder="Hi Gerard, can you introduce yourself?",
                container=False,
                elem_classes="input"
            )
            with gr.Row():
                clear = gr.Button("Clear", scale=1)
                send = gr.Button(
                    value="",
                    variant="primary",
                    icon="./asset/send-message.png",
                    scale=1
                )

            def user(user_message, history):
                return "", history + [[user_message, None]]

            def bot(history):
                bot_message = generate_text(history[-1][0], backend)
                history[-1][1] = ""
                for character in bot_message:
                    history[-1][1] += character
                    time.sleep(0.01)
                    yield history

            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, chatbot, chatbot
            )
            send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, chatbot, chatbot
            )
            clear.click(clear_chat, None, chatbot, queue=False)

            gr.HTML("""
            <p><center><i>Disclaimer: This RAG app is for demonstration only. Hallucination might occur.</i></center></p>
            <p><center>Hosted on 🤗 Spaces | Built with Google Gemini & 🦙 LlamaIndex | Last updated 2025</center></p>
            """)

        app.queue()
        app.launch()

    except Exception as e:
        logger.exception(e)
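The module-level wiring above means the RAG chain can be smoke-tested without launching the Gradio UI. A minimal sketch, assuming the same .env with valid keys is present (the query string is only an example):

from app import generate_text, backend

reply = generate_text("Hi Gerard, can you introduce yourself?", backend)
print(reply)   # the exchange is also appended to the module-level chat_history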
asset/akag-g-only.png
ADDED
asset/send-message.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,9 @@
google-generativeai==0.8.3
llama-index==0.12.12
llama-index-vector-stores-pinecone==0.4.2
transformers==4.48.0
pinecone-client==5.0.1
wandb==0.19.2
# transformers==4.30.2
# llama-index==0.8.48
# wandb==0.15.12
src/config.py
ADDED
@@ -0,0 +1,30 @@
import os
import logging

def setup_logging():
    logging.basicConfig(
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %I:%M:%S %p',
        level=logging.INFO
    )
    return logging.getLogger('backend')

class Config:
    # External services
    GAI_API_KEY = os.environ['GAI_API_KEY']
    PINECONE_API_KEY = os.environ['PINECONE_API_KEY']
    WANDB_API_KEY = os.environ['WANDB_API_KEY']
    WANDB_PROJECT = os.environ['WANDB_PROJECT']

    # Model settings
    TEXT_MODEL_NAME = os.getenv('TEXT_MODEL_NAME', 'gemini-1.5-flash')
    EMB_MODEL_NAME = os.getenv('EMB_MODEL_NAME', 'models/text-embedding-004')
    PC_INDEX_NAME = os.getenv('PC_INDEX_NAME', 'main-index')
    PC_INDEX_NAMESPACE = os.getenv('PC_INDEX_NAMESPACE', 'main')
    CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 32768))
    NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 4098))
    TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 2048))
    TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 200))
    TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
    TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
    SIMILARITY_THRESHOLD = float(os.getenv('SIMILARITY_THRESHOLD', 0.7))
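Config reads the four service keys eagerly via os.environ[...], so the environment has to be populated before the class body runs at import time. A small ordering sketch of what app.py relies on (the values are assumed to come from a local .env):

import dotenv

dotenv.load_dotenv(".env")           # must happen before importing Config,
                                     # otherwise os.environ['GAI_API_KEY'] raises KeyError
from src.config import Config

print(Config.TEXT_MODEL_NAME)        # 'gemini-1.5-flash' unless TEXT_MODEL_NAME overrides it
print(Config.SIMILARITY_THRESHOLD)   # 0.7 unless SIMILARITY_THRESHOLD overrides it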
src/embeddings_model.py
ADDED
@@ -0,0 +1,35 @@
from typing import List, Any
import google.generativeai as genai
from llama_index.core.embeddings import BaseEmbedding

class GEmbeddings(BaseEmbedding):
    def __init__(
        self,
        model_name: str = 'models/text-embedding-004',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name

    def gai_embed_content(self, text: str) -> List[float]:
        return genai.embed_content(model=self._model_name, content=text)

    def _get_query_embedding(self, query: str) -> List[float]:
        embeddings = self.gai_embed_content(query)
        return embeddings['embedding']

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self.gai_embed_content(text)
        return embeddings['embedding']

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = [
            self.gai_embed_content(text)['embedding'] for text in texts
        ]
        return embeddings

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)
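A hedged usage sketch of the embedding wrapper on its own; genai.configure must already have been called (the key below is a placeholder), and get_query_embedding is the public BaseEmbedding entry point that dispatches to the private methods above:

import google.generativeai as genai
from src.embeddings_model import GEmbeddings

genai.configure(api_key="YOUR_GAI_API_KEY")                # placeholder key
emb_model = GEmbeddings(model_name="models/text-embedding-004")

vector = emb_model.get_query_embedding("What does Gerard work on?")
print(len(vector))                                         # 768-dimensional for text-embedding-004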
src/llamaindex_backend.py
ADDED
@@ -0,0 +1,145 @@
from src.embeddings_model import GEmbeddings
from src.text_generation_model import GLLM
from src.pinecone_index import PineconeIndex

from typing import Dict, List, Any, Union
import datetime
import asyncio

from llama_index.core.evaluation import SemanticSimilarityEvaluator
from llama_index.core.base.embeddings.base import SimilarityMode

prompt_template = """
<system>
You are in a role play of Gerard Lee. Gerard is a data enthusiast and humble about his success.
Reply as faithfully as possible and in no more than 5 complete sentences unless <user query> requests to elaborate in detail. Use contents from <context> only, without prior knowledge, except referring to <chat history> for seamless conversation.
</system>

<chat history>
{context_history}
</chat history>

<context>
{context_from_index}
</context>

<user query>
{user_query}
</user query>
"""

class GLlamaIndex():
    def __init__(
        self,
        logger,
        emb_model: GEmbeddings,
        text_model: GLLM,
        index: PineconeIndex,
        similarity_threshold: float
    ) -> None:
        self.logger = logger
        self.emb_model = emb_model
        self.llm = text_model
        self.index = index
        self.evaluator = self._set_evaluator(similarity_threshold)
        self.prompt_template = prompt_template

    def _set_evaluator(self, similarity_threshold: float) -> SemanticSimilarityEvaluator:
        sem_evaluator = SemanticSimilarityEvaluator(
            similarity_mode=SimilarityMode.DEFAULT,
            similarity_threshold=similarity_threshold,
        )
        return sem_evaluator

    def format_history(self, history: List[str]) -> str:
        return "\n".join(list(filter(None, history)))

    async def aget_context_with_history(
        self,
        query: str,
        history: List[str]
    ) -> str:
        if not history:
            result = await self.index.aretrieve_context(query)
            return result["result"]

        extended_query = f"[History]\n{history[-1]}\n[New Query]\n{query}"
        results = await self.index.aretrieve_context_multi(
            [query, extended_query]
        )
        self.logger.info(results)
        eval_results = await self.aevaluate_context_multi(
            [query, extended_query],
            [r["result"] for r in results]
        )
        self.logger.info(eval_results)
        return results[0]["result"] if eval_results[0].score > eval_results[1].score \
            else results[1]["result"]

    async def aevaluate_context(
        self,
        query: str,
        returned_context: str
    ) -> Dict[str, Any]:
        result = await self.evaluator.aevaluate(
            response=returned_context,
            reference=query,
        )
        return result

    async def aevaluate_context_multi(
        self,
        query_list: List[str],
        returned_context_list: List[str]
    ) -> List[Dict]:
        result = await asyncio.gather(*(self.aevaluate_context(query, returned_context) for query, returned_context in zip(query_list, returned_context_list)))
        return result

    def generate_text(
        self,
        query: str,
        history: List[str],
    ) -> str:
        # get chat history
        context_history = self.format_history(history=history)

        # get retrieval context(s) from llama-index vectorstore index
        try:
            # without history, single context retrieval without evaluation
            if not history:
                # w&b trace retrieval context
                result_query_only = self.index.retrieve_context(query)
                context_from_index_selected = result_query_only["result"]

            # with history, multiple async context retrievals, then evaluation to determine which context to choose
            else:
                context_from_index_selected = asyncio.run(self.aget_context_with_history(query=query, history=history))

        except Exception as e:
            self.logger.error(f"Exception {e} occurred when retrieving context\n")

            llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
            result = "Something went wrong. Please try again later."
            return result

        self.logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")

        # generate text with prompt template to roleplay myself
        prompt_with_context = self.prompt_template.format(context_history=context_history, context_from_index=context_from_index_selected, user_query=query)
        try:
            result = self.llm.gai_generate_content(
                prompt=prompt_with_context,
                temperature=0.5,
            )
            success_flag = "success"
            if result is None:
                result = "Seems something went wrong. Please try again later."
                self.logger.error(f"Result with 'None' received\n")
                success_flag = "fail"

        except Exception as e:
            result = "Seems something went wrong. Please try again later."
            self.logger.error(f"Exception {e} occurred\n")
            success_flag = "fail"

        return result
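When history exists, the backend retrieves twice (once for the raw query, once for a history-extended query), scores each retrieved context against its query with the SemanticSimilarityEvaluator, and keeps the higher-scoring one. A standalone restatement of that step, assuming a backend constructed as in app.py:

import asyncio

async def pick_context(backend, query: str, history: list) -> str:
    # retrieve for the raw query and for a query prefixed with the last exchange
    extended = f"[History]\n{history[-1]}\n[New Query]\n{query}"
    results = await backend.index.aretrieve_context_multi([query, extended])
    # score each retrieved context against the query that produced it
    evals = await backend.aevaluate_context_multi(
        [query, extended], [r["result"] for r in results]
    )
    # keep whichever context is semantically closer to its query
    best = 0 if evals[0].score > evals[1].score else 1
    return results[best]["result"]

# context = asyncio.run(pick_context(backend, "What stack did you use?", chat_history))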
src/pinecone_index.py
ADDED
@@ -0,0 +1,42 @@
from typing import List, Dict, Union
import datetime
import asyncio
from pinecone import Pinecone
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.vector_stores.pinecone import PineconeVectorStore

class PineconeIndex:
    def __init__(self, api_key: str, index_name: str, index_namespace: str):
        self._index_name = index_name
        self._index_namespace = index_namespace
        self._pc = Pinecone(api_key=api_key)
        self.pc_index = self._set_index(index_name, index_namespace)

    def _set_index(self, index_name: str, index_namespace: str) -> VectorStoreIndex:
        vector_store = PineconeVectorStore(
            pinecone_index=self._pc.Index(index_name),
            add_sparse_vector=True,
            namespace=index_namespace
        )
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        pc_index = VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)
        return pc_index

    def retrieve_context(self, query: str) -> Dict[str, Union[str, int]]:
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = self.pc_index.as_query_engine(similarity_top_k=3).query(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        return {"result": response.response, "start": start_time, "end": end_time}

    async def aretrieve_context(self, query: str) -> Dict[str, Union[str, int]]:
        start_time = round(datetime.datetime.now().timestamp() * 1000)
        response = await self.pc_index.as_query_engine(
            similarity_top_k=3,
            use_async=True
        ).aquery(query)
        end_time = round(datetime.datetime.now().timestamp() * 1000)
        return {"result": response.response, "start": start_time, "end": end_time}

    async def aretrieve_context_multi(self, query_list: List[str]) -> List[Dict]:
        result = await asyncio.gather(*(self.aretrieve_context(query) for query in query_list))
        return result
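A hedged retrieval sketch; as_query_engine synthesizes an answer with the globally configured models, so Settings.embed_model and Settings.llm must already be set as app.py does, and the index name and namespace below are just the Config defaults:

from src.pinecone_index import PineconeIndex

index = PineconeIndex(
    api_key="YOUR_PINECONE_API_KEY",      # placeholder key
    index_name="main-index",              # Config.PC_INDEX_NAME default
    index_namespace="main",               # Config.PC_INDEX_NAMESPACE default
)
hit = index.retrieve_context("Tell me about Gerard's recent projects")
print(hit["result"])                      # synthesized answer text
print(hit["end"] - hit["start"], "ms")    # crude retrieval-plus-synthesis latency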
src/text_generation_model.py
ADDED
@@ -0,0 +1,61 @@
from typing import Any
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback

class GLLM(CustomLLM):
    def __init__(
        self,
        context_window: int = 32768,
        num_output: int = 4098,
        model_name: str = "gemini-1.5-flash",
        system_instruction: str = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._context_window = context_window
        self._num_output = num_output
        self._model_name = model_name
        self._model = genai.GenerativeModel(model_name, system_instruction=system_instruction)

    def gai_generate_content(self, prompt: str, temperature: float = 0.5) -> str:
        return self._model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
            ),
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            }
        ).text

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        text = self.gai_generate_content(prompt)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        text = self.gai_generate_content(prompt)
        response = ""
        for token in text:
            response += token
            yield CompletionResponse(text=response, delta=token)
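GLLM is registered as Settings.llm for LlamaIndex, but it can also be exercised directly. A brief sketch with a placeholder key; note that stream_complete generates the full reply first and then re-emits it as one growing CompletionResponse per character:

import google.generativeai as genai
from src.text_generation_model import GLLM

genai.configure(api_key="YOUR_GAI_API_KEY")       # placeholder key
llm = GLLM(model_name="gemini-1.5-flash")

print(llm.complete("Introduce Gerard in one sentence.").text)

last = None
for chunk in llm.stream_complete("Introduce Gerard in one sentence."):
    last = chunk                                  # chunk.delta is a single character
print(last.text)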