# sentiment / goodies /synth.py
# dejanseo's picture
# Upload 4 files
# de55574 verified
# raw
# history blame
# 2.25 kB
import os
import csv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the model and tokenizer from the local directory.
# Both are read from the same path, so that folder must hold the tokenizer
# files as well as the model weights.
# NOTE(review): the folder name suggests a bitsandbytes 4-bit Llama 3 8B
# Instruct checkpoint -- confirm; nothing here sets quantization or device
# options explicitly.
model_path = "C:\\models\\llama-3-8b-Instruct-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_path)  # encodes prompts / decodes generations
model = AutoModelForCausalLM.from_pretrained(model_path)  # causal LM used for text generation
# Run configuration: number of rows to generate and the CSV they go into.
num_samples = 100_000
output_file = "raw_data.csv"

# Map each numeric class id to the sentiment phrase that is interpolated into
# the generation prompt. Ids follow the tuple order below, starting at 0.
sentiment_labels = dict(
    enumerate(
        (
            "very positive",
            "positive",
            "somewhat positive",
            "neutral",
            "somewhat negative",
            "negative",
            "very negative",
        )
    )
)
# Make sure the output CSV starts with a header row.
# The header is written when the file is missing *or* present but empty
# (e.g. left behind by an interrupted earlier run). An existing, non-empty
# file is left untouched so that rows appended later are added after
# whatever it already contains.
if not os.path.exists(output_file) or os.path.getsize(output_file) == 0:
    # newline='' lets the csv module control line endings itself.
    with open(output_file, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['text', 'label'])  # column names: generated text, numeric label
# Generate `num_samples` texts, cycling through the sentiment labels, and
# append each one to the CSV as a (text, label) row.
# The file is opened once in append mode and flushed after every row, so rows
# already generated still reach disk if the run is interrupted, without the
# cost of reopening the file on each iteration.
with open(output_file, 'a', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)

    for i in range(num_samples):
        # Cycle through the label ids 0 .. len(sentiment_labels) - 1.
        label = i % len(sentiment_labels)
        sentiment = sentiment_labels[label]

        # Build the prompt with the sentiment phrase for this sample.
        prompt = f"Generate a short article on a random topic and writing style, ensuring the sentiment is {sentiment}. Write nothing but the article text. Do not include the sentiment in the text of the article."
        print(f"Generating sample {i+1}/{num_samples}: {prompt}")  # Output the prompt to console for verification

        # Tokenize via __call__ so the attention mask is produced alongside
        # the input ids, and move both onto the device the model lives on.
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        prompt_length = inputs['input_ids'].shape[-1]

        # Sample a continuation. `max_length` bounds prompt + generated
        # tokens together, so the generated part is at most
        # 200 - prompt_length tokens.
        output = model.generate(
            **inputs,
            max_length=200,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

        # Keep only the newly generated tokens. Slicing by token count is
        # robust even when decoding does not reproduce the prompt text
        # character-for-character (which slicing the decoded string by
        # len(prompt) silently relies on).
        generated_ids = output[0][prompt_length:]
        new_tokens = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Write the raw generated text with its numeric label and push it to
        # disk immediately.
        writer.writerow([new_tokens, label])
        file.flush()

print(f"Data generation completed. Data appended to {output_file}")