Blair Yang committed on
Commit
487b80b
·
1 Parent(s): eabbcfc
Config.py CHANGED
@@ -26,11 +26,9 @@ MODELS = [
26
 
27
  RANDOM_SEED = 42
28
 
29
-
30
-
31
-
32
- # DEFAULT_SUMMARIZER = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
33
- DEFAULT_SUMMARIZER = 'mistralai/Mistral-7B-Instruct-v0.2'
34
  DEFAULT_DATASET = "mmlu"
35
  DEFAULT_TOPIC = random.choice(TOPICS[DEFAULT_DATASET])
36
 
 
26
 
27
  RANDOM_SEED = 42
28
 
29
+ DEFAULT_SUMMARIZER = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
30
+ # DEFAULT_SUMMARIZER = 'NousResearch/Nous-Hermes-2-Mistral-7B-DPO'
31
+ # DEFAULT_SUMMARIZER = 'mistralai/Mistral-7B-Instruct-v0.2'
 
 
32
  DEFAULT_DATASET = "mmlu"
33
  DEFAULT_TOPIC = random.choice(TOPICS[DEFAULT_DATASET])
34
 
__pycache__/Config.cpython-311.pyc CHANGED
Binary files a/__pycache__/Config.cpython-311.pyc and b/__pycache__/Config.cpython-311.pyc differ
 
__pycache__/models.cpython-311.pyc CHANGED
Binary files a/__pycache__/models.cpython-311.pyc and b/__pycache__/models.cpython-311.pyc differ
 
models.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
  import json
4
  import random
5
  import re
 
6
  from abc import ABC, abstractmethod
7
  from typing import List, Dict, Union, Optional
8
 
@@ -10,12 +11,15 @@ from huggingface_hub import InferenceClient
10
  from tenacity import retry, stop_after_attempt, wait_random_exponential
11
  from transformers import AutoTokenizer
12
 
 
 
13
  ROLE_SYSTEM = 'system'
14
  ROLE_USER = 'user'
15
  ROLE_ASSISTANT = 'assistant'
16
 
17
  SUPPORTED_MISTRAL_MODELS = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
18
- SUPPORTED_NOUS_MODELS = ['NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO']
 
19
  SUPPORTED_LLAMA_MODELS = ['meta-llama/Llama-2-70b-chat-hf',
20
  'meta-llama/Llama-2-13b-chat-hf',
21
  'meta-llama/Llama-2-7b-chat-hf']
@@ -93,7 +97,8 @@ class HFAPIModel(Model):
93
 
94
  @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(max=10), reraise=True) # retry if exception
95
  def get_response(self, temperature: float, use_json: bool, timeout: float, cache: bool) -> str:
96
- client = InferenceClient(model=self.name, timeout=timeout)
 
97
  # client = InferenceClient(model=self.name, token=random.choice(HF_API_TOKENS), timeout=timeout)
98
  if not cache:
99
  client.headers["x-use-cache"] = "0"
@@ -156,7 +161,7 @@ class MistralModel(HFAPIModel):
156
  class NousHermesModel(HFAPIModel):
157
 
158
  def __init__(self, system_prompt: str, model_name: str = 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO') -> None:
159
- assert model_name in ['NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO'], 'Model not supported'
160
  super().__init__(model_name, system_prompt)
161
 
162
  def format_messages(self) -> str:
@@ -200,4 +205,4 @@ class LlamaModel(HFAPIModel):
200
  r += f'{content}</s>'
201
  else:
202
  raise ValueError
203
- return r
 
3
  import json
4
  import random
5
  import re
6
+ import os
7
  from abc import ABC, abstractmethod
8
  from typing import List, Dict, Union, Optional
9
 
 
11
  from tenacity import retry, stop_after_attempt, wait_random_exponential
12
  from transformers import AutoTokenizer
13
 
14
+ # from config import *
15
+
16
  ROLE_SYSTEM = 'system'
17
  ROLE_USER = 'user'
18
  ROLE_ASSISTANT = 'assistant'
19
 
20
  SUPPORTED_MISTRAL_MODELS = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
21
+ SUPPORTED_NOUS_MODELS = ['NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO',
22
+ 'NousResearch/Nous-Hermes-2-Mistral-7B-DPO']
23
  SUPPORTED_LLAMA_MODELS = ['meta-llama/Llama-2-70b-chat-hf',
24
  'meta-llama/Llama-2-13b-chat-hf',
25
  'meta-llama/Llama-2-7b-chat-hf']
 
97
 
98
  @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(max=10), reraise=True) # retry if exception
99
  def get_response(self, temperature: float, use_json: bool, timeout: float, cache: bool) -> str:
100
+ # hf_api_token =
101
+ client = InferenceClient(model=self.name, token=os.getenv('HF_API_TOKEN'), timeout=timeout)
102
  # client = InferenceClient(model=self.name, token=random.choice(HF_API_TOKENS), timeout=timeout)
103
  if not cache:
104
  client.headers["x-use-cache"] = "0"
 
161
  class NousHermesModel(HFAPIModel):
162
 
163
  def __init__(self, system_prompt: str, model_name: str = 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO') -> None:
164
+ assert model_name in SUPPORTED_NOUS_MODELS, 'Model not supported'
165
  super().__init__(model_name, system_prompt)
166
 
167
  def format_messages(self) -> str:
 
205
  r += f'{content}</s>'
206
  else:
207
  raise ValueError
208
+ return r