Spaces:
Building
Building
from collections.abc import Generator | |
import google.generativeai.types.generation_types as generation_config_types | |
import pytest | |
from _pytest.monkeypatch import MonkeyPatch | |
from google.ai import generativelanguage as glm | |
from google.ai.generativelanguage_v1beta.types import content as gag_content | |
from google.generativeai import GenerativeModel | |
from google.generativeai.client import _ClientManager, configure | |
from google.generativeai.types import GenerateContentResponse, content_types, safety_types | |
from google.generativeai.types.generation_types import BaseGenerateContentResponse | |
# API key captured by MockGoogleClass.make_client from the client options and
# validated (minimum length 16) by MockGoogleClass.generate_content.
current_api_key: str = ""
class MockGoogleResponseClass:
    """Iterable stand-in for a streamed ``generate_content`` result.

    Iterating yields one empty, not-done ``GenerateContentResponse`` for each
    character of the canned text, followed by one final response with
    ``done=True``.
    """

    # Set to True on the instance just before the final (done) chunk is yielded.
    _done = False

    def __iter__(self):
        canned_text = "it's google!"

        # One intermediate chunk per character of the canned text.
        for _char in canned_text:
            yield GenerateContentResponse(
                done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
            )

        # Mark completion, then emit the terminating chunk.
        self._done = True
        yield GenerateContentResponse(
            done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
        )
class MockGoogleResponseCandidateClass:
    """Minimal stand-in for a response candidate carrying canned content."""

    # Finish reason reported by every mocked candidate.
    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        """Return a ``Content`` with a single text part, ``"it's google!"``.

        Exposed as a property so ``candidate.content`` is read as an attribute,
        like the sibling ``finish_reason`` and like the SDK candidate's
        ``content`` field, rather than evaluating to a bound method.
        """
        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
class MockGoogleClass:
    """Replacement callables that ``setup_google_mock`` patches onto the SDK.

    They stand in for ``GenerativeModel.generate_content``,
    ``_ClientManager.make_client`` and the ``text`` / ``candidates`` accessors
    of ``BaseGenerateContentResponse``.
    """

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Return a single, already-completed (``done=True``) empty response."""
        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])

    @staticmethod
    def generate_content_stream() -> MockGoogleResponseClass:
        """Return an iterable that yields a series of streamed chunk responses."""
        return MockGoogleResponseClass()

    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        **kwargs,
    ) -> GenerateContentResponse:
        """Stand-in for ``GenerativeModel.generate_content``.

        Only ``stream`` affects the result: when true a ``MockGoogleResponseClass``
        iterable is returned, otherwise a single completed response.
        ``contents``, ``generation_config``, ``safety_settings`` and extra keyword
        arguments are accepted for signature compatibility and ignored.

        Raises:
            Exception: if the API key recorded by ``make_client`` is shorter
                than 16 characters.
        """
        # Only reads the module-level key, so no ``global`` declaration is needed.
        if len(current_api_key) < 16:
            raise Exception("Invalid API key")

        if stream:
            return MockGoogleClass.generate_content_stream()

        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        """Canned value for ``BaseGenerateContentResponse.text``.

        A property because the SDK exposes ``text`` as one; patching a plain
        function in its place would make ``response.text`` a bound method.
        """
        return "it's google!"

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        """Canned value for ``BaseGenerateContentResponse.candidates``.

        A property for the same reason as ``generative_response_text``.
        """
        return [MockGoogleResponseCandidateClass()]

    def make_client(self: _ClientManager, name: str):
        """Stand-in for ``_ClientManager.make_client``.

        Resolves the ``glm`` service client class for ``name`` (a ``"<svc>_async"``
        name selects the async client), records the configured API key in the
        module-level ``current_api_key``, and builds the client with its
        ``__init__`` temporarily replaced by a no-op so no real setup runs.

        Args:
            name: service name, optionally suffixed with ``"_async"``.

        Returns:
            The constructed (uninitialised) client instance.
        """
        global current_api_key

        if name.endswith("_async"):
            name = name.split("_")[0]
            cls = getattr(glm, name.title() + "ServiceAsyncClient")
        else:
            cls = getattr(glm, name.title() + "ServiceClient")

        # Attempt to configure using defaults.
        if not self.client_config:
            configure()

        client_options = self.client_config.get("client_options", None)
        if client_options:
            # Captured so generate_content can validate the key later.
            current_api_key = client_options.api_key

        def nop(_client, *args, **kwargs):
            pass

        # Replace __init__ only for the duration of construction, and restore it
        # even if construction raises so the no-op never leaks to later callers.
        original_init = cls.__init__
        cls.__init__ = nop
        try:
            client: glm.GenerativeServiceClient = cls(**self.client_config)
        finally:
            cls.__init__ = original_init

        # Returned unconditionally: gating on ``self.default_metadata`` made the
        # method fall through and hand callers ``None`` whenever it was set.
        return client
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Pytest fixture that patches the google-generativeai SDK with ``MockGoogleClass``.

    Without ``@pytest.fixture`` this generator is never registered, so tests
    requesting ``setup_google_mock`` could not resolve it.

    Args:
        request: pytest's ``FixtureRequest``; accepted but not read here
            (presumably so the fixture can be parametrised indirectly — confirm
            against callers).
        monkeypatch: pytest's monkeypatch fixture used to install the patches.

    Yields:
        None, once the patches are installed; they are undone on teardown.
    """
    monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
    monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
    monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
    monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)

    yield

    monkeypatch.undo()