# Spaces status at capture time: Runtime error
import os

from dotenv import load_dotenv
from text_generation import Client

# Populate os.environ from a .env file (python-dotenv) before the server
# address is read below.
load_dotenv()

# Address of the text-generation server. This is None when the variable is
# not set, and it is passed through to Client unchanged.
PAPERSPACE_IP = os.environ.get("PAPERSPACE_IP")

# Module-wide client shared by generate_text().
client = Client(PAPERSPACE_IP)
def generate_text(input_text, max_new_tokens=20, temperature=1):
    """Send one prompt to the module-level ``client`` and return the reply text.

    Args:
        input_text: Full prompt string sent to the server.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``client.generate``.
        temperature: Sampling temperature, forwarded to ``client.generate``.

    Returns:
        The ``generated_text`` attribute of the response.
    """
    response = client.generate(
        input_text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return response.generated_text
def generate_multi_text(input_text, file_path, max_new_tokens=20, temperature=1,
                        out_path=None, earlystop=None):
    """Generate one model reply per line of a file of user turns.

    Each line of ``file_path`` is stripped, wrapped with ``formatter``,
    appended to ``input_text`` and sent through ``generate_text``. Turns are
    independent: earlier replies are not fed back into later prompts.

    Args:
        input_text: Prompt prefix prepended verbatim to every formatted turn.
        file_path: Text file (read as UTF-8) with one user utterance per line.
        max_new_tokens: Forwarded to ``generate_text``.
        temperature: Forwarded to ``generate_text``.
        out_path: If given, the file is overwritten and each reply is written
            to it as ``"Turn N: <reply>"`` (N starts at 1). If ``None``,
            nothing is written to disk.
        earlystop: If not ``None``, only the first ``earlystop`` lines are used.

    Returns:
        List of generated reply strings, in input order.
    """
    import contextlib

    with open(file_path, "r", encoding="utf-8") as file:
        rows = file.readlines()
    if earlystop is not None:
        rows = rows[:earlystop]
    multi_turns = [formatter(row.strip()) for row in rows]
    print(f"You are playing {len(multi_turns)} turns.")

    generated_text = []
    # out_path is optional: only open a real file when one is given.
    # nullcontext() yields None, which disables the write below.
    out_ctx = (
        open(out_path, "w", encoding="utf-8")
        if out_path is not None
        else contextlib.nullcontext()
    )
    with out_ctx as out_file:
        for i, turn in enumerate(multi_turns, start=1):
            # Only the fixed prefix gives context: prior turns are not included.
            single_turn_resp = generate_text(
                input_text + turn,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            generated_text.append(single_turn_resp)
            if out_file is not None:
                out_file.write(f"Turn {i}: {single_turn_resp}\n")
            print(turn)
            print(single_turn_resp)
            print("-----------")
    return generated_text
def read_text_file(file_path):
    """Return the full contents of a UTF-8 text file.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file contents as ``str``. The file is read in text mode, so
        CRLF line endings come back as LF.
    """
    # Use an explicit encoding because the default is locale-dependent
    # (e.g. cp1252 on Windows), which would mangle or reject non-ASCII
    # prompt text.
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
def formatter(user_prompt):
    """Wrap one user utterance in the ``[User]: ... [You]:`` turn template.

    Leading and trailing whitespace is removed from ``user_prompt`` before it
    is inserted. The result ends with an open ``[You]:`` slot followed by a
    newline.
    """
    cleaned = user_prompt.strip()
    return f"[User]: {cleaned} \n [You]: \n"
def main():
    """Run one batch of user turns against the text-generation server.

    All paths are relative to the current working directory:
    the prompt prefix is read from ``utils/prompts/prompt_attitude.txt``,
    the user turns from ``inappropriate.txt``, and the replies are written
    under ``utils/user_turns/``. The output file name encodes the sampling
    settings (temperature and max_new_tokens).
    """
    cwd = os.getcwd()
    input_text = read_text_file(
        os.path.join(cwd, "utils", "prompts", "prompt_attitude.txt")
    )
    max_new_tokens = 40
    temperature = 0.3
    multi_path = os.path.join(cwd, "inappropriate.txt")

    out_dir = os.path.join(cwd, "utils", "user_turns")
    # open(..., "w") does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir,
        f"multi_turns_conversation_t{temperature}_m{max_new_tokens}"
        "_promptatt_mistral_inapp.txt",
    )
    generate_multi_text(
        input_text,
        multi_path,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        out_path=out_path,
    )


if __name__ == "__main__":
    main()