ngrigg commited on
Commit
c654f8e
1 Parent(s): 2d0dc09
Files changed (1) hide show
  1. llama_models.py +2 -2
llama_models.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Ensure correct model class
3
  import aiohttp
4
 
5
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
@@ -11,7 +11,7 @@ def load_model(model_name):
11
  if not tokenizer or not model:
12
  print("Loading model and tokenizer...")
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name) # Ensure correct model class
15
  print("Model and tokenizer loaded successfully.")
16
  return tokenizer, model
17
 
 
1
  import os
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM # Ensure correct model class
3
  import aiohttp
4
 
5
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
 
11
  if not tokenizer or not model:
12
  print("Loading model and tokenizer...")
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ model = AutoModelForCausalLM.from_pretrained(model_name) # Ensure correct model class
15
  print("Model and tokenizer loaded successfully.")
16
  return tokenizer, model
17