# Custom Gemini LLM wrapper for LlamaIndex.
# (Header reconstructed: the original lines were web-page scrape residue —
# Hugging Face Space status text, file size, and a line-number ruler —
# which is not valid Python.)
from typing import Any, Optional
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from llama_index.core.llms import (
CustomLLM,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
class GLLM(CustomLLM):
    """LlamaIndex ``CustomLLM`` adapter over a ``google.generativeai`` model.

    Wraps a Gemini ``GenerativeModel`` so it can be used anywhere LlamaIndex
    expects an LLM. Completion calls go through :meth:`gai_generate_content`,
    which disables the hate-speech and harassment safety filters.
    """

    def __init__(
        self,
        context_window: int = 32768,
        # NOTE(review): 4098 looks like a typo for 4096 — kept as-is for
        # backward compatibility; confirm the intended value.
        num_output: int = 4098,
        model_name: str = "gemini-1.5-flash",
        system_instruction: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create the wrapper and the underlying ``GenerativeModel``.

        Args:
            context_window: Context size reported via :attr:`metadata`.
            num_output: Max-output size reported via :attr:`metadata`.
            model_name: Gemini model identifier.
            system_instruction: Optional system prompt passed to the model.
            **kwargs: Forwarded to ``CustomLLM.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): CustomLLM is pydantic-based; these underscore-prefixed
        # attributes rely on pydantic's private-attribute handling — confirm
        # they persist on the installed llama_index version.
        self._context_window = context_window
        self._num_output = num_output
        self._model_name = model_name
        self._model = genai.GenerativeModel(
            model_name, system_instruction=system_instruction
        )

    def gai_generate_content(self, prompt: str, temperature: float = 0.5) -> str:
        """Return the model's text response for *prompt*.

        Hate-speech and harassment safety categories are explicitly set to
        ``BLOCK_NONE``; other categories keep the API defaults.

        Args:
            prompt: Text prompt sent to the model.
            temperature: Sampling temperature for generation.
        """
        return self._model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
            ),
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            },
        ).text

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata (values captured at construction time)."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Return a single non-streaming completion for *prompt*."""
        text = self.gai_generate_content(prompt)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """Yield a pseudo-stream of the completion.

        The full response is generated up-front, then re-emitted one
        character at a time (``delta`` is a single character) — this is not
        true server-side streaming.
        """
        text = self.gai_generate_content(prompt)
        response = ""
        for token in text:
            response += token
            yield CompletionResponse(text=response, delta=token)