# internal-mistral / tgi_inference.py
# (Hugging Face upload metadata, preserved as comments so the file parses)
# eva-origin — "Upload folder using huggingface_hub" — commit a8c6d33
from text_generation import Client
import os
from dotenv import load_dotenv
# Load configuration from a local .env file so the server address is not hard-coded.
load_dotenv()
# Address of the TGI (text-generation-inference) server; presumably a full
# "http://host:port" URL — TODO confirm against the .env file.
PAPERSPACE_IP = os.getenv("PAPERSPACE_IP")
# NOTE(review): if PAPERSPACE_IP is unset this becomes Client(None) — consider
# failing fast here with a clear error message.
client = Client(PAPERSPACE_IP)
def generate_text(input_text, max_new_tokens=20, temperature=1):
    """Send *input_text* to the TGI server and return the generated string.

    Thin wrapper around the module-level ``client``; sampling is controlled
    by ``max_new_tokens`` and ``temperature``.
    """
    response = client.generate(
        input_text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return response.generated_text
def generate_multi_text(input_text, file_path, max_new_tokens=20, temperature=1, out_path=None, earlystop=None):
    """Play one generation turn per line of *file_path* and return all replies.

    Each line is stripped, wrapped via :func:`formatter`, appended to
    *input_text*, and sent through :func:`generate_text`.

    Parameters:
        input_text: System/context prompt prepended to every turn.
        file_path: Text file with one user prompt per line.
        max_new_tokens: Generation length cap forwarded to the server.
        temperature: Sampling temperature forwarded to the server.
        out_path: Optional path; when given, each reply is also logged as
            ``"Turn N: <reply>"``. (Bug fix: previously ``open(None, "w")``
            raised ``TypeError`` when out_path was left at its default.)
        earlystop: Optional cap on the number of lines processed.

    Returns:
        list[str]: The generated reply for each turn, in order.
    """
    with open(file_path, "r") as file:
        rows = file.readlines()
    if earlystop is not None:
        rows = rows[:earlystop]
    multi_turns = [formatter(row.strip()) for row in rows]
    print("You are playing " + str(len(multi_turns)) + " turns.")
    generated_text = []
    # Open the log file only when a destination was actually requested.
    out_file = open(out_path, "w") if out_path is not None else None
    try:
        for i, turn in enumerate(multi_turns):
            single_turn_resp = generate_text(input_text + turn,
                                             max_new_tokens=max_new_tokens, temperature=temperature)
            generated_text.append(single_turn_resp)
            if out_file is not None:
                out_file.write(f"Turn {i+1}: {single_turn_resp}\n")
            print(turn)
            print(single_turn_resp)
            print("-----------")
    finally:
        # Close even if the server call raises mid-run.
        if out_file is not None:
            out_file.close()
    return generated_text
def read_text_file(file_path):
    """Return the entire contents of the text file at *file_path*."""
    with open(file_path, 'r') as handle:
        contents = handle.read()
    return contents
def formatter(user_prompt):
    """Wrap a raw user prompt in the two-speaker conversation template."""
    cleaned = user_prompt.strip()
    return "[User]: " + cleaned + " \n [You]: \n"
def main():
    """Run multi-turn generation over the attitude prompt and log replies."""
    cwd = os.getcwd()
    max_new_tokens = 40
    temperature = 0.3
    # System prompt that frames every turn.
    input_text = read_text_file(os.path.join(cwd, 'utils/prompts/prompt_attitude.txt'))
    # One user prompt per line; replies are appended to the out_path log.
    multi_path = os.path.join(cwd,'inappropriate.txt')
    out_path = os.path.join(cwd, f'utils/user_turns/multi_turns_conversation_t{temperature}_m{max_new_tokens}_promptatt_mistral_inapp.txt')
    generate_multi_text(input_text, multi_path,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        out_path=out_path)
if __name__ == "__main__":
main()