---
library_name: transformers
license: apache-2.0
inference: false
---

# Model Card for prhegde/search-query-generator-ecommerce

Generates possible search queries for a given product from its title and description. It can be used to generate synthetic search queries.

Input format: `"Title: " + <product_title> + " - Description: " + <product_details>`
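
The input string can be assembled with a small helper. This is a minimal sketch, assuming the ` - Description: ` separator that appears in the usage example in this card; `build_input` is a hypothetical name and not part of the model's API.

```python
def build_input(title: str, description: str) -> str:
    """Format a product as a single input string for the generator."""
    # Assumption: the " - Description: " separator mirrors the usage
    # example in this card; it is not a documented requirement.
    return "Title: " + title + " - Description: " + description


print(build_input("pine pallet wall décor", "rustic wood sign for the home"))
```

Keeping the formatting in one place makes it less likely that the prompt drifts from the layout shown in the usage example.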

## Development details
The model is trained with a novel adversarial Generator-Retriever framework.

The details of the framework can be found [here](https://github.com/PraveenSH/adversarial-generator-retriever/blob/main/README.md).
A notebook with the code is available [here](https://github.com/PraveenSH/adversarial-generator-retriever/blob/main/generator_retriever.ipynb).

## Using the model
```python
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

MODEL_ID = "prhegde/search-query-generator-ecommerce"
gen_tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
gen_model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)

gen_model.eval()

prod_title = "home sweet home pine pallet wall décor"
prod_desc = "decorate your home with this rustic wood , which is made from high-quality pine pallets . this creates a beautiful rustic look for the kitchen , bedroom , or living room — great gift idea for any occasion ; perfect for holidays , birthdays , or game days"
input_sequence = "Title: " + prod_title + " - Description: " + prod_desc

input_ids = gen_tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')

nsent = 4  # number of queries to sample
with torch.no_grad():
    for _ in range(nsent):
        # do_sample=True draws a fresh sample on each pass, so the queries can differ.
        output = gen_model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)

        target_sequence = gen_tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Target: {target_sequence}')