Zwea Htet commited on
Commit
069e494
·
1 Parent(s): 8013479

update llama custom

Browse files
Files changed (1) hide show
  1. models/llamaCustom.py +7 -7
models/llamaCustom.py CHANGED
@@ -38,10 +38,10 @@ CHUNK_OVERLAP_RATION = 0.2
38
 
39
 
40
  @st.cache_resource
41
- def load_model(mode_name: str):
42
  # llm_model_name = "bigscience/bloom-560m"
43
- tokenizer = AutoTokenizer.from_pretrained(mode_name)
44
- model = AutoModelForCausalLM.from_pretrained(mode_name, config="T5Config")
45
 
46
  pipe = pipeline(
47
  task="text-generation",
@@ -62,11 +62,11 @@ class CustomLLM(LLM):
62
  llm_model_name: str
63
  pipeline: Any
64
 
65
- def __init__(self, model_name: str):
66
- # super().__init__()
67
 
68
- self.llm_model_name = model_name
69
- self.pipeline = load_model(mode_name=model_name)
70
 
71
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
72
  prompt_length = len(prompt)
 
38
 
39
 
40
  @st.cache_resource
41
+ def load_model(model_name: str):
42
  # llm_model_name = "bigscience/bloom-560m"
43
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
44
+ model = AutoModelForCausalLM.from_pretrained(model_name, config="T5Config")
45
 
46
  pipe = pipeline(
47
  task="text-generation",
 
62
  llm_model_name: str
63
  pipeline: Any
64
 
65
+ def __init__(self, llm_model_name: str):
66
+ super().__init__()
67
 
68
+ self.llm_model_name = llm_model_name
69
+ self.pipeline = load_model(mode_name=llm_model_name)
70
 
71
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
72
  prompt_length = len(prompt)