Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Wonderplex
		
	committed on
		
		
					Commit 
							
							·
						
						8b07d8c
	
1
								Parent(s):
							
							f1c7954
								
fixed parsing errors and extra_info (#56)
Browse files
- app.py +35 -5
- message_classes.py +1 -1
- sotopia_pi_generate.py +49 -49
- start_app.sh +4 -0
- utils.py +5 -5
    	
        app.py
    CHANGED
    
    | @@ -1,6 +1,7 @@ | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            from collections import defaultdict
         | 
| 3 | 
             
            import json
         | 
|  | |
| 4 |  | 
| 5 | 
             
            import gradio as gr
         | 
| 6 |  | 
| @@ -25,6 +26,21 @@ RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl" | |
| 25 |  | 
| 26 | 
             
            ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
         | 
| 27 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 | 
             
            @cache
         | 
| 29 | 
             
            def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
         | 
| 30 | 
             
                with open(env_file, 'r') as f:
         | 
| @@ -126,13 +142,27 @@ def create_bot_info(bot_agent_dropdown): | |
| 126 | 
             
                return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
         | 
| 127 |  | 
| 128 | 
             
            def create_user_goal(environment_dropdown):
         | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
| 131 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 132 |  | 
| 133 | 
             
            def create_bot_goal(environment_dropdown):
         | 
| 134 | 
             
                _, environment_dict, _, _ = get_sotopia_profiles()
         | 
| 135 | 
             
                text = environment_dict[environment_dropdown].agent_goals[1]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 136 | 
             
                return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
         | 
| 137 |  | 
| 138 | 
             
            def sotopia_info_accordion(accordion_visible=True):
         | 
| @@ -147,7 +177,7 @@ def sotopia_info_accordion(accordion_visible=True): | |
| 147 | 
             
                            interactive=True,
         | 
| 148 | 
             
                        )
         | 
| 149 | 
             
                        model_name_dropdown = gr.Dropdown(
         | 
| 150 | 
            -
                            choices= | 
| 151 | 
             
                            value=DEFAULT_MODEL_SELECTION,
         | 
| 152 | 
             
                            interactive=True,
         | 
| 153 | 
             
                            label="Model Selection"
         | 
| @@ -215,7 +245,7 @@ def chat_tab(): | |
| 215 |  | 
| 216 | 
             
                    context = get_context_prompt(bot_agent, user_agent, environment)
         | 
| 217 | 
             
                    dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
         | 
| 218 | 
            -
                    prompt_history = f"{context} | 
| 219 | 
             
                    agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
         | 
| 220 | 
             
                    return agent_action.to_natural_language()
         | 
| 221 |  | 
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            from collections import defaultdict
         | 
| 3 | 
             
            import json
         | 
| 4 | 
            +
            from typing import Literal
         | 
| 5 |  | 
| 6 | 
             
            import gradio as gr
         | 
| 7 |  | 
|  | |
| 26 |  | 
| 27 | 
             
            ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
         | 
| 28 |  | 
| 29 | 
            +
            MODEL_OPTIONS = [
         | 
| 30 | 
            +
                "gpt-3.5-turbo",
         | 
| 31 | 
            +
                "gpt-4",
         | 
| 32 | 
            +
                "gpt-4-turbo",
         | 
| 33 | 
            +
                "cmu-lti/sotopia-pi-mistral-7b-BC_SR",
         | 
| 34 | 
            +
                "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
         | 
| 35 | 
            +
                "mistralai/Mistral-7B-Instruct-v0.1"
         | 
| 36 | 
            +
                # "mistralai/Mixtral-8x7B-Instruct-v0.1",
         | 
| 37 | 
            +
                # "togethercomputer/llama-2-7b-chat",
         | 
| 38 | 
            +
                # "togethercomputer/llama-2-70b-chat",
         | 
| 39 | 
            +
                # "togethercomputer/mpt-30b-chat",
         | 
| 40 | 
            +
                # "together_ai/togethercomputer/llama-2-7b-chat",
         | 
| 41 | 
            +
                # "together_ai/togethercomputer/falcon-7b-instruct",
         | 
| 42 | 
            +
            ]
         | 
| 43 | 
            +
             | 
| 44 | 
             
            @cache
         | 
| 45 | 
             
            def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
         | 
| 46 | 
             
                with open(env_file, 'r') as f:
         | 
|  | |
| 142 | 
             
                return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
         | 
| 143 |  | 
| 144 | 
             
            def create_user_goal(environment_dropdown):
         | 
| 145 | 
            +
                _, environment_dict, _, _ = get_sotopia_profiles()
         | 
| 146 | 
            +
                text = environment_dict[environment_dropdown].agent_goals[0]
         | 
| 147 | 
            +
                text = text.replace('(', '').replace(')', '')
         | 
| 148 | 
            +
                if "<extra_info>" in text:
         | 
| 149 | 
            +
                    text = text.replace("<extra_info>", "\n\n")
         | 
| 150 | 
            +
                    text = text.replace("</extra_info>", "\n")
         | 
| 151 | 
            +
                if "<strategy_hint>" in text:
         | 
| 152 | 
            +
                    text = text.replace("<strategy_hint>", "\n\n")
         | 
| 153 | 
            +
                    text = text.replace("</strategy_hint>", "\n")
         | 
| 154 | 
            +
                return gr.Textbox(label="User Agent Goal", lines=4, value=text)
         | 
| 155 |  | 
| 156 | 
             
            def create_bot_goal(environment_dropdown):
         | 
| 157 | 
             
                _, environment_dict, _, _ = get_sotopia_profiles()
         | 
| 158 | 
             
                text = environment_dict[environment_dropdown].agent_goals[1]
         | 
| 159 | 
            +
                text = text.replace('(', '').replace(')', '')
         | 
| 160 | 
            +
                if "<extra_info>" in text:
         | 
| 161 | 
            +
                    text = text.replace("<extra_info>", "\n\n")
         | 
| 162 | 
            +
                    text = text.replace("</extra_info>", "\n")
         | 
| 163 | 
            +
                if "<strategy_hint>" in text:
         | 
| 164 | 
            +
                    text = text.replace("<strategy_hint>", "\n\n")
         | 
| 165 | 
            +
                    text = text.replace("</strategy_hint>", "\n")
         | 
| 166 | 
             
                return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
         | 
| 167 |  | 
| 168 | 
             
            def sotopia_info_accordion(accordion_visible=True):
         | 
|  | |
| 177 | 
             
                            interactive=True,
         | 
| 178 | 
             
                        )
         | 
| 179 | 
             
                        model_name_dropdown = gr.Dropdown(
         | 
| 180 | 
            +
                            choices=MODEL_OPTIONS,
         | 
| 181 | 
             
                            value=DEFAULT_MODEL_SELECTION,
         | 
| 182 | 
             
                            interactive=True,
         | 
| 183 | 
             
                            label="Model Selection"
         | 
|  | |
| 245 |  | 
| 246 | 
             
                    context = get_context_prompt(bot_agent, user_agent, environment)
         | 
| 247 | 
             
                    dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
         | 
| 248 | 
            +
                    prompt_history = f"{context}{dialogue_history}"
         | 
| 249 | 
             
                    agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
         | 
| 250 | 
             
                    return agent_action.to_natural_language()
         | 
| 251 |  | 
    	
        message_classes.py
    CHANGED
    
    | @@ -120,7 +120,7 @@ class AgentAction(Message): | |
| 120 | 
             
                        case "none":
         | 
| 121 | 
             
                            return "did nothing"
         | 
| 122 | 
             
                        case "speak":
         | 
| 123 | 
            -
                            return f | 
| 124 | 
             
                        case "non-verbal communication":
         | 
| 125 | 
             
                            return f"[{self.action_type}] {self.argument}"
         | 
| 126 | 
             
                        case "action":
         | 
|  | |
| 120 | 
             
                        case "none":
         | 
| 121 | 
             
                            return "did nothing"
         | 
| 122 | 
             
                        case "speak":
         | 
| 123 | 
            +
                            return f"{self.argument}"
         | 
| 124 | 
             
                        case "non-verbal communication":
         | 
| 125 | 
             
                            return f"[{self.action_type}] {self.argument}"
         | 
| 126 | 
             
                        case "action":
         | 
    	
        sotopia_pi_generate.py
    CHANGED
    
    | @@ -28,6 +28,9 @@ from utils import format_docstring | |
| 28 | 
             
            from langchain_callback_handler import LoggingCallbackHandler
         | 
| 29 |  | 
| 30 | 
             
            HF_TOKEN_KEY_FILE="./hf_token.key"
         | 
|  | |
|  | |
|  | |
| 31 |  | 
| 32 | 
             
            OutputType = TypeVar("OutputType", bound=object)
         | 
| 33 | 
             
            log = logging.getLogger("generate")
         | 
| @@ -44,59 +47,54 @@ def generate_action( | |
| 44 | 
             
                """
         | 
| 45 | 
             
                Using langchain to generate an example episode
         | 
| 46 | 
             
                """
         | 
| 47 | 
            -
                try:
         | 
| 48 | 
             
                    # Normal case, model as agent
         | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 |  | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
            -
             | 
| 66 | 
            -
             | 
| 67 | 
            -
             | 
| 68 | 
            -
             | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 | 
            -
                except Exception:
         | 
| 77 | 
            -
             | 
|  | |
| 78 |  | 
| 79 | 
             
            @cache
         | 
| 80 | 
            -
            def prepare_model(model_name | 
| 81 | 
             
                compute_type = torch.float16
         | 
| 82 | 
            -
                if os.path.exists(hf_token_key_file):
         | 
| 83 | 
            -
                    with open (hf_token_key_file, 'r') as f:
         | 
| 84 | 
            -
                        hf_token = f.read().strip()
         | 
| 85 | 
            -
                else:
         | 
| 86 | 
            -
                    hf_token = os.environ["HF_TOKEN"]
         | 
| 87 |  | 
| 88 | 
             
                if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
         | 
| 89 | 
            -
                    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",  | 
| 90 | 
             
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 91 | 
             
                    "mistralai/Mistral-7B-Instruct-v0.1",
         | 
| 92 | 
             
                    cache_dir="./.cache",
         | 
| 93 | 
            -
                    device_map='cuda' | 
| 94 | 
            -
                    token=hf_token
         | 
| 95 | 
             
                    )
         | 
| 96 | 
             
                    model = PeftModel.from_pretrained(model, model_name).to("cuda")
         | 
| 97 |  | 
| 98 | 
             
                elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
         | 
| 99 | 
            -
                    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",  | 
| 100 | 
             
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 101 | 
             
                    "mistralai/Mistral-7B-Instruct-v0.1",
         | 
| 102 | 
             
                    cache_dir="./.cache",
         | 
| @@ -106,18 +104,17 @@ def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE): | |
| 106 | 
             
                        bnb_4bit_use_double_quant=True,
         | 
| 107 | 
             
                        bnb_4bit_quant_type="nf4",
         | 
| 108 | 
             
                        bnb_4bit_compute_dtype=compute_type,
         | 
| 109 | 
            -
                        ) | 
| 110 | 
            -
                    token=hf_token
         | 
| 111 | 
             
                    )
         | 
| 112 | 
            -
                    model = PeftModel.from_pretrained(model, model_name).to("cuda")
         | 
| 113 |  | 
| 114 | 
             
                elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
         | 
| 115 | 
            -
                    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",  | 
|  | |
| 116 | 
             
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 117 | 
             
                    "mistralai/Mistral-7B-Instruct-v0.1",
         | 
| 118 | 
             
                    cache_dir="./.cache",
         | 
| 119 | 
            -
                    device_map='cuda' | 
| 120 | 
            -
                    token=hf_token
         | 
| 121 | 
             
                    )
         | 
| 122 |  | 
| 123 | 
             
                else:
         | 
| @@ -146,7 +143,6 @@ def obtain_chain_hf( | |
| 146 | 
             
                                return_full_text=False, 
         | 
| 147 | 
             
                                do_sample=True,
         | 
| 148 | 
             
                                num_beams=3,
         | 
| 149 | 
            -
                                length_penalty=-1.0,
         | 
| 150 | 
             
                                )
         | 
| 151 | 
             
                hf = HuggingFacePipeline(pipeline=pipe)
         | 
| 152 | 
             
                chain = LLMChain(llm=hf, prompt=chat_prompt_template)
         | 
| @@ -171,6 +167,8 @@ def generate( | |
| 171 | 
             
                    input_values["format_instructions"] = output_parser.get_format_instructions()
         | 
| 172 | 
             
                result = chain.predict([logging_handler], **input_values)
         | 
| 173 | 
             
                prompt = logging_handler.retrive_prompt()
         | 
|  | |
|  | |
| 174 | 
             
                try:
         | 
| 175 | 
             
                    parsed_result = output_parser.parse(result)
         | 
| 176 | 
             
                except KeyboardInterrupt:
         | 
| @@ -183,6 +181,7 @@ def generate( | |
| 183 | 
             
                    reformat_parsed_result = format_bad_output(
         | 
| 184 | 
             
                        result, format_instructions=output_parser.get_format_instructions()
         | 
| 185 | 
             
                    )
         | 
|  | |
| 186 | 
             
                    parsed_result = output_parser.parse(reformat_parsed_result)
         | 
| 187 | 
             
                log.info(f"Generated result: {parsed_result}")
         | 
| 188 | 
             
                return parsed_result
         | 
| @@ -223,7 +222,7 @@ def obtain_chain( | |
| 223 | 
             
                """
         | 
| 224 | 
             
                Using langchain to sample profiles for participants
         | 
| 225 | 
             
                """
         | 
| 226 | 
            -
                if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit"]:
         | 
| 227 | 
             
                    return obtain_chain_hf(
         | 
| 228 | 
             
                        model_name=model_name,
         | 
| 229 | 
             
                        template=template,
         | 
| @@ -247,10 +246,11 @@ def obtain_chain( | |
| 247 | 
             
                return chain
         | 
| 248 |  | 
| 249 | 
             
            def _return_fixed_model_version(model_name: str) -> str:
         | 
| 250 | 
            -
                 | 
| 251 | 
             
                    "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
         | 
| 252 | 
             
                    "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
         | 
| 253 | 
             
                    "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
         | 
| 254 | 
             
                    "gpt-4": "gpt-4-0613",
         | 
| 255 | 
             
                    "gpt-4-turbo": "gpt-4-1106-preview",
         | 
| 256 | 
            -
                } | 
|  | 
|  | |
| 28 | 
             
            from langchain_callback_handler import LoggingCallbackHandler
         | 
| 29 |  | 
| 30 | 
             
            HF_TOKEN_KEY_FILE="./hf_token.key"
         | 
| 31 | 
            +
            if os.path.exists(HF_TOKEN_KEY_FILE):
         | 
| 32 | 
            +
                with open(HF_TOKEN_KEY_FILE, "r") as f:
         | 
| 33 | 
            +
                    os.environ["HF_TOKEN"] = f.read().strip()
         | 
| 34 |  | 
| 35 | 
             
            OutputType = TypeVar("OutputType", bound=object)
         | 
| 36 | 
             
            log = logging.getLogger("generate")
         | 
|  | |
| 47 | 
             
                """
         | 
| 48 | 
             
                Using langchain to generate an example episode
         | 
| 49 | 
             
                """
         | 
| 50 | 
            +
                # try:
         | 
| 51 | 
             
                    # Normal case, model as agent
         | 
| 52 | 
            +
                template = """
         | 
| 53 | 
            +
                    Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
         | 
| 54 | 
            +
                    You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
         | 
| 55 | 
            +
                    Note that {agent}'s goal is only visible to you.
         | 
| 56 | 
            +
                    You should try your best to achieve {agent}'s goal in a way that align with their character traits.
         | 
| 57 | 
            +
                    Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n
         | 
| 58 | 
            +
                    {history}.
         | 
| 59 | 
            +
                    You are at Turn #{turn_number}. Your available action types are
         | 
| 60 | 
            +
                    {action_list}.
         | 
| 61 | 
            +
                    Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
         | 
| 62 |  | 
| 63 | 
            +
                    Please only generate a JSON string including the action type and the argument.
         | 
| 64 | 
            +
                    Your action should follow the given format:
         | 
| 65 | 
            +
                    {format_instructions}
         | 
| 66 | 
            +
                """
         | 
| 67 | 
            +
                return generate(
         | 
| 68 | 
            +
                    model_name=model_name,
         | 
| 69 | 
            +
                    template=template,
         | 
| 70 | 
            +
                    input_values=dict(
         | 
| 71 | 
            +
                        agent=agent,
         | 
| 72 | 
            +
                        turn_number=str(turn_number),
         | 
| 73 | 
            +
                        history=history,
         | 
| 74 | 
            +
                        action_list=" ".join(action_types),
         | 
| 75 | 
            +
                    ),
         | 
| 76 | 
            +
                    output_parser=PydanticOutputParser(pydantic_object=AgentAction),
         | 
| 77 | 
            +
                    temperature=temperature,
         | 
| 78 | 
            +
                )
         | 
| 79 | 
            +
                # except Exception as e:
         | 
| 80 | 
            +
                #     print(e)
         | 
| 81 | 
            +
                #     return AgentAction(action_type="none", argument="")
         | 
| 82 |  | 
| 83 | 
             
            @cache
         | 
| 84 | 
            +
            def prepare_model(model_name):
         | 
| 85 | 
             
                compute_type = torch.float16
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 86 |  | 
| 87 | 
             
                if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
         | 
| 88 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
         | 
| 89 | 
             
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 90 | 
             
                    "mistralai/Mistral-7B-Instruct-v0.1",
         | 
| 91 | 
             
                    cache_dir="./.cache",
         | 
| 92 | 
            +
                    device_map='cuda'
         | 
|  | |
| 93 | 
             
                    )
         | 
| 94 | 
             
                    model = PeftModel.from_pretrained(model, model_name).to("cuda")
         | 
| 95 |  | 
| 96 | 
             
                elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
         | 
| 97 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
         | 
| 98 | 
             
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 99 | 
             
                    "mistralai/Mistral-7B-Instruct-v0.1",
         | 
| 100 | 
             
                    cache_dir="./.cache",
         | 
|  | |
| 104 | 
             
                        bnb_4bit_use_double_quant=True,
         | 
| 105 | 
             
                        bnb_4bit_quant_type="nf4",
         | 
| 106 | 
             
                        bnb_4bit_compute_dtype=compute_type,
         | 
| 107 | 
            +
                        )
         | 
|  | |
| 108 | 
             
                    )
         | 
| 109 | 
            +
                    model = PeftModel.from_pretrained(model, model_name[0:-5]).to("cuda")
         | 
| 110 |  | 
| 111 | 
             
                elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
         | 
| 112 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
         | 
| 113 | 
            +
                    tokenizer.model_max_length = 4096
         | 
| 114 | 
             
                    model = AutoModelForCausalLM.from_pretrained(
         | 
| 115 | 
             
                    "mistralai/Mistral-7B-Instruct-v0.1",
         | 
| 116 | 
             
                    cache_dir="./.cache",
         | 
| 117 | 
            +
                    device_map='cuda'
         | 
|  | |
| 118 | 
             
                    )
         | 
| 119 |  | 
| 120 | 
             
                else:
         | 
|  | |
| 143 | 
             
                                return_full_text=False, 
         | 
| 144 | 
             
                                do_sample=True,
         | 
| 145 | 
             
                                num_beams=3,
         | 
|  | |
| 146 | 
             
                                )
         | 
| 147 | 
             
                hf = HuggingFacePipeline(pipeline=pipe)
         | 
| 148 | 
             
                chain = LLMChain(llm=hf, prompt=chat_prompt_template)
         | 
|  | |
| 167 | 
             
                    input_values["format_instructions"] = output_parser.get_format_instructions()
         | 
| 168 | 
             
                result = chain.predict([logging_handler], **input_values)
         | 
| 169 | 
             
                prompt = logging_handler.retrive_prompt()
         | 
| 170 | 
            +
                print(f"Prompt:\n {prompt}")
         | 
| 171 | 
            +
                print(f"Result:\n {result}")
         | 
| 172 | 
             
                try:
         | 
| 173 | 
             
                    parsed_result = output_parser.parse(result)
         | 
| 174 | 
             
                except KeyboardInterrupt:
         | 
|  | |
| 181 | 
             
                    reformat_parsed_result = format_bad_output(
         | 
| 182 | 
             
                        result, format_instructions=output_parser.get_format_instructions()
         | 
| 183 | 
             
                    )
         | 
| 184 | 
            +
                    print(f"Reformatted result:\n {reformat_parsed_result}")
         | 
| 185 | 
             
                    parsed_result = output_parser.parse(reformat_parsed_result)
         | 
| 186 | 
             
                log.info(f"Generated result: {parsed_result}")
         | 
| 187 | 
             
                return parsed_result
         | 
|  | |
| 222 | 
             
                """
         | 
| 223 | 
             
                Using langchain to sample profiles for participants
         | 
| 224 | 
             
                """
         | 
| 225 | 
            +
                if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit", "mistralai/Mistral-7B-Instruct-v0.1"]:
         | 
| 226 | 
             
                    return obtain_chain_hf(
         | 
| 227 | 
             
                        model_name=model_name,
         | 
| 228 | 
             
                        template=template,
         | 
|  | |
| 246 | 
             
                return chain
         | 
| 247 |  | 
| 248 | 
             
            def _return_fixed_model_version(model_name: str) -> str:
         | 
| 249 | 
            +
                model_version_map = {
         | 
| 250 | 
             
                    "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
         | 
| 251 | 
             
                    "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
         | 
| 252 | 
             
                    "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
         | 
| 253 | 
             
                    "gpt-4": "gpt-4-0613",
         | 
| 254 | 
             
                    "gpt-4-turbo": "gpt-4-1106-preview",
         | 
| 255 | 
            +
                }
         | 
| 256 | 
            +
                return model_version_map[model_name] if model_name in model_version_map else model_name
         | 
    	
        start_app.sh
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            export OPENAI_API_KEY=$(cat openai_api.key)
         | 
| 2 | 
            +
            export HF_TOKEN=$(cat hf_token.key)
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            python app.py
         | 
    	
        utils.py
    CHANGED
    
    | @@ -44,11 +44,11 @@ def dialogue_history_prompt(message, history, user_agent, bot_agent): | |
| 44 | 
             
                    user_turn_idx = idx * 2
         | 
| 45 | 
             
                    bot_turn_idx = idx * 2 + 1
         | 
| 46 | 
             
                    if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
         | 
| 47 | 
            -
                        bot_message =  | 
| 48 | 
            -
                    dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx} | 
| 49 | 
            -
                 | 
| 50 | 
            -
                dialogue_history = f"{dialogue_history}\n\nTurn #{ | 
| 51 | 
            -
                return dialogue_history,  | 
| 52 |  | 
| 53 | 
             
            def format_docstring(docstring: str) -> str:
         | 
| 54 | 
             
                """Format a docstring for use in a prompt template."""
         | 
|  | |
| 44 | 
             
                    user_turn_idx = idx * 2
         | 
| 45 | 
             
                    bot_turn_idx = idx * 2 + 1
         | 
| 46 | 
             
                    if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
         | 
| 47 | 
            +
                        bot_message = 'said:"' + bot_message + '"'
         | 
| 48 | 
            +
                    dialogue_history = f"""{dialogue_history}\n\nTurn #{user_turn_idx} {user_agent.name} said: "{user_message}"\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"""
         | 
| 49 | 
            +
                curr_turn_idx = len(history) * 2
         | 
| 50 | 
            +
                dialogue_history = f"""{dialogue_history}\n\nTurn #{curr_turn_idx} {user_agent.name} said: "{message}"\n"""
         | 
| 51 | 
            +
                return dialogue_history, curr_turn_idx + 1
         | 
| 52 |  | 
| 53 | 
             
            def format_docstring(docstring: str) -> str:
         | 
| 54 | 
             
                """Format a docstring for use in a prompt template."""
         | 
