---
library_name: transformers
license: apache-2.0
datasets:
- Universal-NER/Pile-NER-type
- Universal-NER/Pile-NER-definition
language:
- en
base_model:
- google/flan-t5-small
pipeline_tag: text2text-generation
tags:
- named-entity-recognition
- generated_from_trainer
---
# flan-t5-small-ner

This model is a fine-tuned version of [google/flan-t5-small](https://huggingface.co/google/flan-t5-small)
on 200,000 randomly sampled (text, entity) pairs from the
[Universal-NER/Pile-NER-type](https://huggingface.co/datasets/Universal-NER/Pile-NER-type) and
[Universal-NER/Pile-NER-definition](https://huggingface.co/datasets/Universal-NER/Pile-NER-definition) datasets.

It achieves the following results:
- Validation loss: 0.5393
- Input tokens seen: 332,318,598

## Model Description

flan-t5-small-ner extracts entities matching a given type or definition from text, for example person, company, school, or technology.
It builds on the FLAN-T5 architecture, which performs strongly across a wide range of natural language processing tasks.

Example:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

model_path = "agentlans/flan-t5-small-ner"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)

def custom_split(s): # Processes the output from the model
    parts = s.split("<|sep|>")
    if not s.endswith("<|end|>"):
        parts = parts[:-1] # If output is truncated, then don't include last item
    else:
        parts[-1] = parts[-1].replace("<|end|>", "") # Remove the marker tokens
    return [p.strip() for p in parts if p.strip()]

def find_entities(input_text, entity_type):
    txt = entity_type + "<|sep|>" + input_text + "<|end|>" # Important: need exact input format
    inputs = tokenizer(txt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return custom_split(decoded)

# Example usage
input_text = "In the bustling metropolis of New York City, Apple Inc. sponsored a conference where Dr. Elena Rodriguez presented groundbreaking research about neuroscience and AI."
print(find_entities(input_text, "person")) # ['Elena Rodriguez']
print(find_entities(input_text, "company")) # ['Apple Inc.']
print(find_entities(input_text, "fruit")) # []
print(find_entities(input_text, "subject")) # ['neuroscience', 'AI']
```
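
If you need several entity types for the same passage, you can simply loop over `find_entities`. The snippet below is a minimal sketch that assumes the code above has already been run; the helper name `extract_many` and its dictionary return format are illustrative, not part of the model's API:

```python
# Sketch: convenience wrapper around find_entities() defined above (hypothetical helper).
def extract_many(input_text, entity_types):
    """Return a dict mapping each requested entity type to the entities found."""
    return {etype: find_entities(input_text, etype) for etype in entity_types}

print(extract_many(input_text, ["person", "company", "subject"]))
# e.g. {'person': ['Elena Rodriguez'], 'company': ['Apple Inc.'], 'subject': ['neuroscience', 'AI']}
```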

## Limitations

- False positives and negatives are possible.
- May struggle with specialized knowledge or fine distinctions.
- Performance may vary for very short or long texts.
- English language only.
- Consider privacy when processing sensitive text.

## Training Procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- optimizer: AdamW (`adamw_torch`) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 5.0
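
For reference, these settings correspond roughly to the following `Seq2SeqTrainingArguments`. This is a minimal sketch assuming a standard `Seq2SeqTrainer` setup; the output directory and any arguments not listed above are placeholders, not values taken from the original training script:

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-small-ner",  # placeholder, not the original path
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    seed=42,
    optim="adamw_torch",               # AdamW with default betas=(0.9, 0.999), epsilon=1e-08
    lr_scheduler_type="linear",
    num_train_epochs=5.0,
)
```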

### Training results

| Training Loss | Epoch | Step  | Validation Loss | Input Tokens Seen |
|:-------------:|:-----:|:-----:|:---------------:|:-----------------:|
| 0.8398        | 1.0   | 19991 | 0.6227          | 66451084          |
| 0.7203        | 2.0   | 39982 | 0.5679          | 132976438         |
| 0.6479        | 3.0   | 59973 | 0.5605          | 199402582         |
| 0.6023        | 4.0   | 79964 | 0.5427          | 265875340         |
| 0.5879        | 5.0   | 99955 | 0.5393          | 332318598         |

## Framework Versions

- Transformers: 4.46.3
- PyTorch: 2.5.1+cu124
- Datasets: 3.2.0
- Tokenizers: 0.20.3