from application.llm.base import BaseLLM
from application.core.settings import settings

class AnthropicLLM(BaseLLM):
    """LLM backend for Anthropic's legacy Completions API (Claude models)."""

    def __init__(self, api_key=None):
        # Imported lazily so the anthropic package is only required when
        # this backend is actually used.
        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
        self.api_key = api_key or settings.ANTHROPIC_API_KEY  # fall back to the key from settings
        self.anthropic = Anthropic(api_key=self.api_key)
        self.HUMAN_PROMPT = HUMAN_PROMPT  # "\n\nHuman:"
        self.AI_PROMPT = AI_PROMPT  # "\n\nAssistant:"

    def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
        """Return a completion for the given chat-style messages."""
        # Collapse the message list into a single completions prompt: the
        # first message carries the retrieved context, the last the question.
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
        if stream:
            # Delegate to the streaming path, which expects the original
            # messages list rather than the already-built prompt.
            return self.gen_stream(model, messages, max_tokens=max_tokens, **kwargs)

        completion = self.anthropic.completions.create(
            model=model,
            max_tokens_to_sample=max_tokens,
            stream=False,
            prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
        )
        return completion.completion

    def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
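        """Stream a completion, yielding text chunks as they arrive."""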
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
        stream_response = self.anthropic.completions.create(
            model=model,
            prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
            max_tokens_to_sample=max_tokens,
            stream=True,
        )

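        # Each streamed event exposes the incremental text on .completion.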
        for completion in stream_response:
            yield completion.completion
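

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's public API).
# Assumes settings.ANTHROPIC_API_KEY is configured and that "claude-2" is
# available on your account via Anthropic's legacy Completions API; swap in
# whichever completions-capable model you actually have access to.
if __name__ == "__main__":
    llm = AnthropicLLM()
    messages = [
        {"role": "system", "content": "Retrieved documentation passages go here."},
        {"role": "user", "content": "What does this endpoint return?"},
    ]

    # Blocking call: returns the full completion as a single string.
    print(llm.gen(model="claude-2", messages=messages))

    # Streaming call: consume text chunks as they arrive.
    for chunk in llm.gen_stream(model="claude-2", messages=messages):
        print(chunk, end="", flush=True)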