File size: 879 Bytes
976b948
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""NVIDIA LLM Implementation"""

from llama_index.llms.nvidia import NVIDIA

from src.llm.base_llm_provider import BaseLLMProvider
from src.llm.enums import DEFAULT_LLM_API_BASE


class NvidiaLLM(BaseLLMProvider):
    """LLM provider that delegates to llama_index's ``NVIDIA`` client."""

    def __init__(
        self,
        model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
        temperature: float = 0.0,
        base_url: str = "https://integrate.api.nvidia.com/v1",
    ):
        """Initiate NVIDIA client.

        Args:
            model: Model identifier handed to the ``NVIDIA`` client.
            temperature: Sampling temperature handed to the ``NVIDIA`` client.
            base_url: Endpoint for the client. When it equals
                ``DEFAULT_LLM_API_BASE`` it is NOT forwarded, so the
                ``NVIDIA`` constructor's own default for ``base_url`` applies.
        """
        client_options = {"model": model, "temperature": temperature}

        # Only pass an explicit endpoint when it differs from the
        # project-wide default; otherwise leave it out of the call entirely.
        if base_url != DEFAULT_LLM_API_BASE:
            client_options["base_url"] = base_url

        self._client = NVIDIA(**client_options)

    def complete(self, prompt: str = "") -> str:
        """Run a completion for ``prompt`` and return the response via ``str()``."""
        response = self._client.complete(prompt)
        return str(response)