---
license: llama2
base_model:
- meta-llama/Llama-2-7b
pipeline_tag: text-generation
library_name: transformers
---
|
|
|
|
The code below shows how this Buyer Persona generator can be used.
|
|
|
This model was developed for [MarketFit.ai](https://danlou.co/marketfitai).
|
|
|
```python |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from tqdm import tqdm |
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub repository the tokenizer and weights are loaded from.
model_id = "danlou/persona-generator-llama-2-7b-qlora-merged"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in half precision and let `device_map="auto"` decide
# where they are placed.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)
|
|
|
|
|
def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    Args:
        lst: A sequence that supports ``len()`` and slicing (e.g. a list or
            a ``range``).
        n: Maximum number of items per chunk; must be positive.

    Yields:
        Consecutive slices of ``lst`` holding at most ``n`` items each. Only
        the last slice can be shorter than ``n``.

    Raises:
        ValueError: If ``n`` is not positive. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    # A negative step makes range() empty, which would silently drop every
    # item; a zero step fails with an unrelated-looking range() error.
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    for start in range(0, len(lst), n):
        yield lst[start:start + n]
|
|
|
|
|
def parse_outputs(output_text):
    """Parse one generated persona into a dictionary.

    The text must consist of exactly two lines:

    1. ``<name>, <age>`` -- exactly one comma, with an integer age.
    2. A description longer than 16 characters (measured before stripping).

    Args:
        output_text: The generated text with the prompt already removed.

    Returns:
        A dict with keys ``'name'`` (str), ``'age'`` (int) and
        ``'description'`` (str); all values are whitespace-stripped.

    Raises:
        ValueError: With message ``'Malformed output.'`` when the two-line
            layout is not respected, or ``'Malformed output (age).'`` when
            the age is not an integer. ``ValueError`` derives from
            ``Exception``, so callers using ``except Exception`` still catch
            it.
    """
    lines = output_text.split('\n')
    header_parts = lines[0].split(',')

    # Explicit checks instead of `assert`: assertions are removed when Python
    # runs with -O, which would let malformed text through unnoticed.
    # `lines[1]` is only evaluated once the two-line check has passed.
    if len(lines) != 2 or len(header_parts) != 2 or len(lines[1]) <= 16:
        raise ValueError('Malformed output.')

    name, age_text = (part.strip() for part in header_parts)

    try:
        age = int(age_text)
    except ValueError as err:
        raise ValueError('Malformed output (age).') from err

    return {'name': name, 'age': age, 'description': lines[1].strip()}
|
|
|
|
|
|
|
def generate_personas(product, n=1, batch_size=32, parse=True):
    """Sample buyer personas for a product description.

    Args:
        product: Free-text product description inserted into the prompt.
        n: Number of samples to draw from the model.
        batch_size: Maximum number of sequences requested per
            ``model.generate`` call.
        parse: When True, each sample is passed through ``parse_outputs`` and
            stored as a dict; when False, the raw generated text is stored.

    Returns:
        A list of personas (dicts when ``parse`` is True, strings otherwise).
        It can hold fewer than ``n`` entries: samples that fail to parse are
        printed and skipped.
    """
    prompt = f"### Instruction:\nDescribe the ideal persona for this product:\n{product}\n\n### Response:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Number of prompt tokens; used below to separate the prompt from the
    # generated continuation at token level.
    prompt_len = inputs.input_ids.shape[1]

    personas = []
    with tqdm(total=n) as pbar:
        for batch in chunks(range(n), batch_size):
            # NOTE: per the transformers generation docs, `max_length` and
            # `min_length` count the prompt tokens as well as the new ones.
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                do_sample=True,
                num_beams=1,
                num_return_sequences=len(batch),
                max_length=512,
                min_length=32,
                temperature=0.9,
            )

            for output_ids in outputs:
                # Drop the prompt by token count rather than by slicing the
                # decoded string with len(prompt): decoding is not guaranteed
                # to reproduce the prompt character-for-character, and any
                # difference would shift the cut point.
                generated_ids = output_ids[prompt_len:]
                text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                if not parse:
                    personas.append(text)
                    continue

                try:
                    persona = parse_outputs(text)
                except Exception as e:
                    # Best effort: a malformed sample is reported and
                    # skipped instead of aborting the whole run.
                    print(e)
                    continue
                personas.append(persona)

            # Progress counts attempted samples, including skipped ones.
            pbar.update(len(batch))

    return personas
|
|
|
|
|
# Example product listing: title followed by its category in parentheses.
product = "Koonie 10000mAh Rechargeable Desk Fan, 8-Inch Battery Operated Clip on Fan, USB Fan, 4 Speeds, Strong Airflow, Sturdy Clamp for Golf Cart Office Desk Outdoor Travel Camping Tent Gym Treadmill, Black (USB Gadgets > USB Fans)"

# Draw three samples and print each resulting persona on its own line.
personas = generate_personas(product, n=3)

for persona in personas:
    print(persona)
|
|
|
# Persona 1 - The yoga instructor |
|
# {'name': 'Sarah', 'age': 28, 'description': 'Yoga instructor who is passionate about health and fitness. She works from a home studio where she also practices yoga and meditation. Sarah values products that are eco-friendly and sustainable. She loves products that are versatile and can be used for different purposes. Sarah is looking for a product that is durable and can withstand frequent use. She values products that are stylish and aesthetically pleasing.'} |
|
# Persona 2 - The golf enthusiast |
|
# {'name': 'Sophia', 'age': 60, 'description': "Golf enthusiast. Sophia spends most of her weekends on the golf course, and she needs a fan that she can carry around in her golf cart. She needs a fan that's lightweight, easy to clip on, and has a long battery life. She also wants a fan that's affordable, especially since she plays at different courses."}
|
# Persona 3 - The truck driver |
|
# {'name': 'Mike', 'age': 32, 'description': "Truck driver who spends most of his day on the road. The cab of his truck can get hot and stuffy, and Mike needs a fan that can keep him comfortable and alert while he's driving. He needs a fan that's easy to install and adjust, so he can keep it on his dashboard and direct the airflow where he needs it most."} |
|
``` |