File size: 2,259 Bytes
4b722ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from augmentation import prompt_generation as pg
from information_retrieval import info_retrieval as ir
from src.text_generation.models import (
    Llama3,
    Mistral,
    Gemma2,
    Llama3Point1,
    GPT4,
    Claude3Point5Sonnet,
)
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)


def generate_response(model, prompt):
    """
    
    Function that initializes the LLM class and calls the generate function.

    Args:
        - messages: list; contains the system and user prompt 
        - model: class; the class of the llm to be initialized
    
    """

    logger.info(f"Initializing LLM configuration for {model}")
    llm = model()

    logger.info("Generating response")
    try:
        response = llm.generate(prompt)
    except Exception as e:
        logger.error(f"Error while generating response for {model}: {e}")
        response = 'ERROR'

    return response


def test(model):
    """
    Run the full RAG pipeline once (retrieve -> augment -> generate) for the
    given model class, without sustainability re-weighting.

    Args:
        model: class; the LLM class forwarded to ``generate_response``.

    Returns:
        The generated response (``'ERROR'`` sentinel included) on success,
        or None if any pipeline stage raised.
    """
    context_params = {
        'limit': 3,       # number of context passages to retrieve
        'reranking': 0    # reranking disabled
    }

    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "

    # Stage 1: retrieve context for the query.
    logger.info("Retrieving context..")
    try:
        context = ir.get_context(query=query, **context_params)
    except Exception as e:
        logger.error("Error while trying to get context: %s", e)
        return None

    # Stage 2: augment the prompt with the retrieved context
    # (sustainability=0 -> no sustainability re-weighting).
    logger.info("Retrieved context, augmenting prompt (without sustainability)..")
    try:
        without_sfairness = pg.augment_prompt(
            query=query,
            context=context,
            sustainability=0,
            params=context_params
        )
    except Exception as e:
        logger.error("Error while trying to augment prompt: %s", e)
        return None

    # Stage 3: generate the response with the requested model.
    logger.info("Augmented prompt, initializing %s and generating response..", model)
    try:
        response = generate_response(model, without_sfairness)
    except Exception as e:
        # Fix: this was logged at INFO level; failures belong at ERROR.
        logger.error("Error while generating response: %s", e)
        return None

    return response


if __name__ == "__main__":
    # Smoke-test the pipeline end to end with Claude 3.5 Sonnet.
    print(test(Claude3Point5Sonnet))