| """Load models to use them as a narrator and a common-sense oracle in the PAYADOR pipeline.""" | |
| import google.generativeai as genai | |
| import requests | |
| import replicate | |
| import os | |
def get_llm(model_name: str = "gemini-1.5-flash") -> object:
    """Return a model wrapper for ``model_name``.

    Args:
        model_name: Identifier of the model to load. Gemini identifiers are
            served by ``GeminiModel``; Llama 3 identifiers are served by
            ``ReplicateModel``.

    Returns:
        A ``GeminiModel`` or ``ReplicateModel`` instance, or ``None`` when
        ``model_name`` is not one of the identifiers listed below.
    """
    # NOTE(review): the ``API_key`` values passed here are the literal names
    # of environment variables, not secret values -- confirm this is intended.
    if model_name in ("gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"):
        return GeminiModel(API_key="GOOGLE_API_KEY", model_name=model_name)

    if model_name in ("meta/meta-llama-3-70b", "meta/meta-llama-3-70b-instruct"):
        return ReplicateModel(API_key="REPLICATE_API_TOKEN", model_name=model_name)

    # Unknown identifiers are not an error here; the caller receives None.
    return None
class ReplicateModel():
    """Wrapper that prompts a model through ``replicate.run``."""

    def __init__ (self, API_key:str, model_name:str = "meta/meta-llama-3-70b-instruct") -> None:
        """Store the model identifier and the sampling temperature.

        Args:
            API_key: Accepted but not read by this constructor.
                NOTE(review): presumably the replicate client obtains its
                credentials elsewhere (e.g. the environment) -- confirm.
            model_name: Identifier handed to ``replicate.run`` on each prompt.
        """
        self.model_name = model_name
        self.temperature = 0.7

    def prompt_model(self,system_msg: str, user_msg:str) -> str:
        """Prompt the Replicate model and return its reply as one string.

        Args:
            system_msg: Text placed in the system section of the template.
            user_msg: Text sent as the ``prompt`` input; the ``{prompt}``
                placeholder in the template marks where it belongs.

        Returns:
            The pieces produced by ``replicate.run`` concatenated together.
        """
        # Chat template with header/end-of-turn markers. ``{{prompt}}`` is
        # escaped so a literal ``{prompt}`` placeholder is sent to the service.
        template = (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_msg}<|eot_id|><|start_header_id|>user<|end_header_id|>"
            "\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        payload = {
            "top_p": 0.1,
            "min_tokens": 0,
            "temperature": self.temperature,
            "prompt": user_msg,
            "prompt_template": template,
        }
        chunks = replicate.run(self.model_name, input=payload)
        return "".join(chunks)
class GeminiModel():
    """Wrapper that prompts a Gemini model through ``google.generativeai``."""

    def __init__ (self, API_key:str, model_name:str = "gemini-1.0-pro") -> None:
        """Initialize the Gemini model.

        Args:
            API_key: Accepted but not read; the key actually used is the
                value of the ``GOOGLE_API_KEY`` environment variable.
            model_name: Identifier passed to ``genai.GenerativeModel``.
        """
        # Every listed category gets the same "BLOCK_NONE" threshold.
        harm_categories = (
            "HARM_CATEGORY_DANGEROUS",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
        self.safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in harm_categories
        ]

        google_key = os.getenv("GOOGLE_API_KEY")
        genai.configure(api_key=google_key)
        self.model = genai.GenerativeModel(model_name)

    def prompt_model(self,system_msg: str, user_msg:str) -> str:
        """Prompt the Gemini model and return the text of its response.

        Args:
            system_msg: Instructions placed first in the prompt.
            user_msg: User content, separated from ``system_msg`` by a
                blank line.

        Returns:
            The ``text`` attribute of the object returned by
            ``generate_content``.
        """
        full_prompt = "\n\n".join((system_msg, user_msg))
        response = self.model.generate_content(
            full_prompt,
            safety_settings=self.safety_settings,
        )
        return response.text