File size: 1,718 Bytes
2358b5a
 
9a2eaeb
7b4a8f1
2358b5a
 
 
 
 
 
 
 
7b4a8f1
2358b5a
7b4a8f1
2358b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b4a8f1
2358b5a
 
7b4a8f1
2358b5a
 
 
 
 
 
 
 
 
 
 
 
7b4a8f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
import os
from pydantic import Extra
import requests
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

class URALLM(LLM):
    """LangChain ``LLM`` that forwards prompts to a remote text-generation HTTP endpoint.

    The endpoint URL defaults to the ``URL`` environment variable (read once,
    when the class is defined). The ``TOKEN`` environment variable, when set,
    is sent verbatim as the ``Authorization`` header on every request.
    """

    # Explicitly annotated so the field stays valid even when ``URL`` is unset
    # (an un-annotated ``None`` default gives pydantic no type to infer).
    llm_url: Optional[str] = os.environ.get("URL")
    # Seconds to wait for the endpoint; ``None`` restores the old wait-forever behavior.
    request_timeout: Optional[float] = 120.0

    class Config:
        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM implementation."""
        return "URALLM"

    def _call(
        self,
        inputs: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """POST the prompt to the endpoint and return the generated text.

        Args:
            inputs: Prompt text, sent as the ``inputs`` field of the JSON payload.
            stop: Must be ``None``; stop sequences are not supported.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The ``generated_text`` value from the endpoint's JSON response.

        Raises:
            ValueError: If ``stop`` is given or no endpoint URL is configured.
            requests.RequestException: On connection failure, timeout, or a
                non-2xx HTTP status.
            KeyError: If the response has no ``generated_text`` entry.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        if not self.llm_url:
            raise ValueError(
                "No endpoint URL configured: set the URL environment variable "
                "or pass llm_url."
            )

        payload = {
            "inputs": inputs,
            "parameters": {
                "max_new_tokens": 512,
                "temperature": 0.01,
                "repetition_penalty": 1.1,
                "do_sample": True,
                "top_k": 10,
            },
        }

        headers = {"Content-Type": "application/json"}
        token = os.environ.get("TOKEN")
        if token:
            # Only send the header when a token exists, instead of relying on
            # the HTTP client to discard a None value.
            headers["Authorization"] = token

        # NOTE(review): verify=False disables TLS certificate verification.
        # Kept as-is because the deployment may rely on it (e.g. a self-signed
        # certificate) -- confirm, and prefer passing a CA bundle path instead.
        response = requests.post(
            self.llm_url,
            json=payload,
            headers=headers,
            verify=False,
            timeout=self.request_timeout,
        )
        response.raise_for_status()

        body = response.json()
        # Some text-generation servers wrap the result in a list of
        # generations; unwrap the first entry so both response shapes work.
        if isinstance(body, list):
            body = body[0]
        return body["generated_text"]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"llmUrl": self.llm_url}