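"""
Smoke test for the text-generation pipeline: retrieves context for a sample
travel query, augments the prompt with that context (sustainability disabled),
and generates a response with one of the imported LLM classes.
"""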
from augmentation import prompt_generation as pg
from information_retrieval import info_retrieval as ir
from src.text_generation.models import (
    Llama3,
    Mistral,
    Gemma2,
    Llama3Point1,
    GPT4,
    Claude3Point5Sonnet,
)
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)


def generate_response(model, prompt):
    """
    Initializes the LLM class and calls its generate function.

    Args:
        - model: class; the class of the LLM to be initialized
        - prompt: the augmented prompt, containing the system and user messages
    """
    logger.info(f"Initializing LLM configuration for {model}")
    llm = model()

    logger.info("Generating response")
    try:
        response = llm.generate(prompt)
    except Exception as e:
        logger.error(f"Error while generating response for {model}: {e}")
        response = 'ERROR'

    return response
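
# NOTE: the model classes imported from src.text_generation.models are assumed
# to follow the interface used above -- a no-argument constructor and a
# generate(prompt) method returning the response text. Hypothetical sketch,
# inferred from the calls in generate_response, not from the actual module:
#
#   class Claude3Point5Sonnet:
#       def generate(self, prompt) -> str:
#           ...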


def test(model):
    """
    Runs the full pipeline for a sample query: context retrieval, prompt
    augmentation (without sustainability), and response generation.
    Returns the model response, or None if any stage fails.
    """
    context_params = {
        'limit': 3,      # retrieve at most 3 context entries
        'reranking': 0   # reranking disabled
    }
    # model = Llama3Point1
    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "
    # without sustainability
    logger.info("Retrieving context..")
    try:
        context = ir.get_context(query=query, **context_params)
    except Exception as e:
        logger.error(f"Error while trying to get context: {e}")
        return None
    logger.info("Retrieved context, augmenting prompt (without sustainability)..")
    try:
        without_sfairness = pg.augment_prompt(
            query=query,
            context=context,
            sustainability=0,
            params=context_params
        )
    except Exception as e:
        logger.error(f"Error while trying to augment prompt: {e}")
        return None
    # return without_sfairness
    logger.info(f"Augmented prompt, initializing {model} and generating response..")
    try:
        response = generate_response(model, without_sfairness)
    except Exception as e:
        logger.error(f"Error while generating response: {e}")
        return None

    return response


if __name__ == "__main__":
    response = test(Claude3Point5Sonnet)
    print(response)
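    # Any of the other imported model classes (Llama3, Mistral, Gemma2,
    # Llama3Point1, GPT4) can be swapped in and passed to test() the same way.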