"""Generate a synthetic sentiment dataset with a locally stored causal LM.

For every sample the script asks the model for a short article written with
a target sentiment and appends ``(generated_text, numeric_label)`` to a CSV
file. Labels cycle through ``sentiment_labels`` so the classes stay balanced.
"""

import csv
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local directory the model and tokenizer are loaded from.
model_path = "C:\\models\\llama-3-8b-Instruct-bnb-4bit"

# Parameters for generating data
num_samples = 100000
output_file = 'raw_data.csv'
# Budget for *generated* tokens only, so the article length does not shrink
# when the prompt gets longer (unlike ``max_length``, which counts both).
max_new_tokens = 200

# Sentiment labels as textual descriptions, keyed by the numeric label that
# is written to the CSV.
sentiment_labels = {
    0: "very positive",
    1: "positive",
    2: "somewhat positive",
    3: "neutral",
    4: "somewhat negative",
    5: "negative",
    6: "very negative",
}


def build_prompt(sentiment):
    """Return the generation prompt for the given sentiment description."""
    return (
        "Generate a short article on a random topic and writing style, "
        f"ensuring the sentiment is {sentiment}. "
        "Write nothing but the article text. "
        "Do not include the sentiment in the text of the article."
    )


def generate_article(model, tokenizer, prompt):
    """Sample one completion for *prompt* and return only the new text.

    Args:
        model: A loaded causal language model exposing ``generate`` and
            ``device``.
        tokenizer: The tokenizer matching *model*.
        prompt: The text the model should continue.

    Returns:
        The decoded continuation with the prompt and surrounding whitespace
        removed.
    """
    # NOTE(review): the prompt is sent as raw text, not through the
    # tokenizer's chat template -- confirm that is intended for an
    # "Instruct" checkpoint.
    # Calling the tokenizer (rather than ``encode``) also yields the
    # attention mask; the tensors must live on the same device as the model.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    prompt_length = inputs['input_ids'].shape[-1]

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Fall back to EOS when the tokenizer defines no padding token.
        pad_token_id = tokenizer.eos_token_id

    # No gradients are needed for sampling; skipping them saves memory.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=pad_token_id,
        )

    # Slice by token count instead of by ``len(prompt)`` characters: the
    # decoded text is not guaranteed to reproduce the prompt exactly, so a
    # character offset can clip or leak text.
    new_token_ids = output[0][prompt_length:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()


def main():
    """Load the model and append ``num_samples`` generated rows to the CSV."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Write the header for a new file, and also for an existing empty one.
    needs_header = (
        not os.path.exists(output_file) or os.path.getsize(output_file) == 0
    )

    # Open the file once instead of once per sample.
    with open(output_file, 'a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['text', 'label'])

        for i in range(num_samples):
            # Cycle labels 0..6 so every sentiment gets the same share.
            label = i % len(sentiment_labels)
            prompt = build_prompt(sentiment_labels[label])
            # Output the prompt to console for verification
            print(f"Generating sample {i+1}/{num_samples}: {prompt}")

            article = generate_article(model, tokenizer, prompt)

            writer.writerow([article, label])
            # Flush per row so an interrupted run keeps everything so far.
            file.flush()

    print(f"Data generation completed. Data appended to {output_file}")


if __name__ == "__main__":
    main()