Spaces:
Running
Running
Upload 23 files
Browse files- app.py +10 -0
- article.html +44 -0
- interfaces/__init__.py +4 -0
- interfaces/iupac2smiles.py +27 -0
- interfaces/iupac2style.py +22 -0
- interfaces/landing.py +6 -0
- interfaces/smiles2iupac.py +36 -0
- materials/introduction.html +67 -0
- modeling/__init__.py +2 -0
- modeling/config.py +133 -0
- modeling/docstrings.py +217 -0
- modeling/model.py +612 -0
- models/IUPAC2SMILES/config.json +33 -0
- models/IUPAC2SMILES/generation_config.json +7 -0
- models/IUPAC2SMILES/model.safetensors +3 -0
- models/SMILES2IUPAC/config.json +33 -0
- models/SMILES2IUPAC/generation_config.json +7 -0
- models/SMILES2IUPAC/model.safetensors +3 -0
- requirements.txt +3 -0
- test.py +12 -0
- utils/__init__.py +2 -0
- utils/main_model.py +47 -0
- utils/rdkit_utils.py +39 -0
app.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from interfaces import smiles2iupac, iupac2smiles, iupac2style, landing
|
3 |
+
|
4 |
+
|
5 |
+
demo = gr.TabbedInterface([landing, smiles2iupac, iupac2smiles, iupac2style],
|
6 |
+
["Introduction", "SMILES-to-IUPAC", "IUPAC-to-SMILES", "IUPAC style prediction"],
|
7 |
+
title="ChemConverters 🧪🔬🧬👨🏻🔬",
|
8 |
+
theme=gr.themes.Base())
|
9 |
+
|
10 |
+
demo.launch(share=True)
|
article.html
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>ChemConverters App Description</title>
|
7 |
+
<style>
|
8 |
+
body {
|
9 |
+
font-family: Arial, sans-serif;
|
10 |
+
margin: 30px;
|
11 |
+
line-height: 4;
|
12 |
+
}
|
13 |
+
.link-button {
|
14 |
+
display: inline-block;
|
15 |
+
margin: 50px 50px;
|
16 |
+
padding: 50px;
|
17 |
+
background-color: #007bff;
|
18 |
+
color: white;
|
19 |
+
text-decoration: none;
|
20 |
+
border-radius: 50px;
|
21 |
+
font-weight: bold;
|
22 |
+
}
|
23 |
+
.link-button:hover {
|
24 |
+
background-color: #0056b3;
|
25 |
+
}
|
26 |
+
</style>
|
27 |
+
</head>
|
28 |
+
<body>
|
29 |
+
<p>With ChemConverters, you can effortlessly:</p>
|
30 |
+
<ul>
|
31 |
+
<li>Convert SMILES strings to IUPAC names and vice versa 🔄</li>
|
32 |
+
<li>Choose your preferred IUPAC naming style: BASE, SYSTEMATIC, or TRADITIONAL 📚</li>
|
33 |
+
<li>Validate chemical naming with molecules fingerprints similarity for accuracy checks ✔️</li>
|
34 |
+
</ul>
|
35 |
+
<p>Developed by the brilliant minds at Knowladgator, this app showcases the abilities of our chemical transformer models. Whether you're working on a research project, studying for an exam, or just exploring the chemical universe, ChemConverters is your go-to tool. 🛠️</p>
|
36 |
+
<p>Remember, chemistry is not just about reactions; it's about connections. Let's build those connections together! 💫</p>
|
37 |
+
<!-- Links Section -->
|
38 |
+
<div>
|
39 |
+
<a href="https://www.knowledgator.com/" class="link-button" target="_blank">🔗Visit our Website 🔗 </a>
|
40 |
+
<a href="https://www.linkedin.com/company/knowledgator/" class="link-button" target="_blank">💼Follow on LinkedIn 💼 </a>
|
41 |
+
<a href="https://huggingface.co/knowledgator/" class="link-button" target="_blank">🤗Hugging Face Profile🤗</a>
|
42 |
+
</div>
|
43 |
+
</body>
|
44 |
+
</html>
|
interfaces/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .smiles2iupac import smiles2iupac
|
2 |
+
from .iupac2smiles import iupac2smiles
|
3 |
+
from .iupac2style import iupac2style
|
4 |
+
from .landing import landing
|
interfaces/iupac2smiles.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import ChemicalConverter, validate_smiles2iupac, plot_mol
|
3 |
+
|
4 |
+
def convert(chemical_name, plot):
|
5 |
+
# Initialize the ChemicalConverter
|
6 |
+
converter = ChemicalConverter(mode="IUPAC2SMILES")
|
7 |
+
converted_name = ""
|
8 |
+
plot_image = None
|
9 |
+
converted_name = converter.convert(chemical_name)[6:]
|
10 |
+
if plot:
|
11 |
+
plot_image = plot_mol(converted_name)
|
12 |
+
return converted_name, plot_image
|
13 |
+
|
14 |
+
|
15 |
+
iupac2smiles = gr.Interface(
|
16 |
+
fn=convert,
|
17 |
+
allow_flagging='auto',
|
18 |
+
inputs=[
|
19 |
+
gr.Textbox(label="Enter your IUPAC name", placeholder="Enter IUPAC name here"),
|
20 |
+
gr.Checkbox(label="Plot molecule", value=True)
|
21 |
+
],
|
22 |
+
outputs=[gr.Text(label="Converted Name"),
|
23 |
+
gr.Image(type='pil', label="Molecule Plot", height=170, width=890)],
|
24 |
+
examples=[
|
25 |
+
["ethanol", True]
|
26 |
+
],
|
27 |
+
)
|
interfaces/iupac2style.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import ChemicalConverter, validate_smiles2iupac, plot_mol
|
3 |
+
|
4 |
+
def convert(chemical_name, plot):
|
5 |
+
# Initialize the ChemicalConverter
|
6 |
+
converter = ChemicalConverter(mode="IUPAC2SMILES")
|
7 |
+
converted_name = converter.convert(chemical_name)[:6]
|
8 |
+
styles = {"<SYST>": "SYSTEMATIC", "<TRAD>": "TRADITIONAL", "<BASE>": "BASE"}
|
9 |
+
return styles.get(converted_name, "")
|
10 |
+
|
11 |
+
|
12 |
+
iupac2style = gr.Interface(
|
13 |
+
fn=convert,
|
14 |
+
allow_flagging='auto',
|
15 |
+
inputs=[
|
16 |
+
gr.Textbox(label="Enter your IUPAC name", placeholder="Enter IUPAC name here"),
|
17 |
+
],
|
18 |
+
outputs=[gr.Text(label="IUPAC style")],
|
19 |
+
examples=[
|
20 |
+
["propan-2-yl 2-[4-(4-chlorophenyl)carbonylphenoxy]-2-methyl-propanoate"]
|
21 |
+
],
|
22 |
+
)
|
interfaces/landing.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
with open('materials/introduction.html', 'r', encoding='utf-8') as file:
|
4 |
+
html_description = file.read()
|
5 |
+
|
6 |
+
landing = gr.HTML(html_description)
|
interfaces/smiles2iupac.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils import ChemicalConverter, validate_smiles2iupac, plot_mol
|
3 |
+
|
4 |
+
def convert(chemical_name, style, validate, plot):
|
5 |
+
# Initialize the ChemicalConverter
|
6 |
+
converter = ChemicalConverter(mode="SMILES2IUPAC")
|
7 |
+
converted_name = ""
|
8 |
+
validation_score = ""
|
9 |
+
plot_image = None
|
10 |
+
style_prefix = "<" + style[:4] + ">"
|
11 |
+
converted_name = converter.convert(style_prefix + chemical_name)
|
12 |
+
if validate:
|
13 |
+
validation_score = validate_smiles2iupac(chemical_name, converted_name)
|
14 |
+
if plot:
|
15 |
+
plot_image = plot_mol(chemical_name)
|
16 |
+
return converted_name, validation_score, plot_image
|
17 |
+
|
18 |
+
smiles2iupac = gr.Interface(
|
19 |
+
fn=convert,
|
20 |
+
allow_flagging='auto',
|
21 |
+
inputs=[
|
22 |
+
gr.Textbox(label="Enter your SMILES name", placeholder="Enter your SMILES name here"),
|
23 |
+
gr.Radio(
|
24 |
+
choices=["BASE", "SYSTEMATIC", "TRADITIONAL"],
|
25 |
+
label="Choose desired IUPAC style",
|
26 |
+
),
|
27 |
+
gr.Checkbox(label="Validate with molecular similarity", value=False),
|
28 |
+
gr.Checkbox(label="Plot molecule", value=True)
|
29 |
+
],
|
30 |
+
outputs=[gr.Text(label="Converted Name"),
|
31 |
+
gr.Text(label="Input-Target similarity score"),
|
32 |
+
gr.Image(type='pil', label="Molecule Plot", height=170, width=890)],
|
33 |
+
examples=[
|
34 |
+
["CCO", "BASE", True, True]
|
35 |
+
],
|
36 |
+
)
|
materials/introduction.html
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>ChemConverters App Description</title>
|
7 |
+
<style>
|
8 |
+
body {
|
9 |
+
font-family: Arial, sans-serif;
|
10 |
+
margin: 10px;
|
11 |
+
line-height: 1.6;
|
12 |
+
}
|
13 |
+
.link-button {
|
14 |
+
display: inline-flex;
|
15 |
+
align-items: center;
|
16 |
+
justify-content: center;
|
17 |
+
margin: 10px;
|
18 |
+
padding: 10px;
|
19 |
+
background-color: white;
|
20 |
+
border: 1px solid grey; /* Added border to make the button visible against white background */
|
21 |
+
color: #007bff; /* Text color changed to make it visible against white background */
|
22 |
+
text-decoration: none;
|
23 |
+
border-radius: 10px;
|
24 |
+
text-align: center;
|
25 |
+
vertical-align: middle;
|
26 |
+
box-sizing: border-box;
|
27 |
+
}
|
28 |
+
.link-button:hover {
|
29 |
+
background-color: #c0dcfc;
|
30 |
+
}
|
31 |
+
.link-button img {
|
32 |
+
height: 30px;
|
33 |
+
width: auto;
|
34 |
+
display: block;
|
35 |
+
}
|
36 |
+
.links-container {
|
37 |
+
text-align: center; /* Center the container's content */
|
38 |
+
margin: auto; /* Auto margins for horizontal centering if necessary */
|
39 |
+
display: flex; /* Use flexbox */
|
40 |
+
justify-content: center; /* Center flex items horizontally */
|
41 |
+
flex-wrap: wrap; /* Allow items to wrap */
|
42 |
+
}
|
43 |
+
</style>
|
44 |
+
</head>
|
45 |
+
<body>
|
46 |
+
<h2>Welcome to ChemConverters! 🧪🔬</h2>
|
47 |
+
<h3>With ChemConverters, you can effortlessly:</h3>
|
48 |
+
<ol>
|
49 |
+
<li>Convert SMILES strings to IUPAC names and vice versa 🔄</li>
|
50 |
+
<li>Choose your preferred IUPAC naming style: BASE, SYSTEMATIC, or TRADITIONAL 📚</li>
|
51 |
+
<li>Validate chemical naming with molecules fingerprints similarity for accuracy checks ✔️</li>
|
52 |
+
</ol>
|
53 |
+
<h3>What is ChemConverters?</h3>
|
54 |
+
<p>ChemConverters serves as a foundational showcase of our technological capabilities within the chemical domain. The models deployed in this application represent our entry-level offerings, designed to provide a glimpse into the potential applications of our advanced solutions. For access to our comprehensive suite of larger and more precise models, we invite interested parties to engage directly with us. Developed by the brilliant minds at Knowladgator, this app showcases the abilities of our chemical transformer models. Whether you're working on a research project, studying for an exam, or just exploring the chemical universe, ChemConverters is your go-to tool 🛠.<p>
|
55 |
+
<h3>Models Availability</h3>
|
56 |
+
<p>All models used in the applications are available on <a href="https://huggingface.co/knowledgator/" target="_blank">our Hugging Face page</a>. For translating from SMILES to IUPAC, the <a href="https://huggingface.co/knowledgator/SMILES2IUPAC-canonical-base" target="_blank">knowledgator/SMILES2IUPAC-canonical-base</a> model was used. To translate from IUPAC to SMILES, the <a href="https://huggingface.co/knowledgator/IUPAC2SMILES-canonical-base" target="_blank">knowledgator/IUPAC2SMILES-canonical-base</a> model was used.</p>
|
57 |
+
<h3>Citation</h3>
|
58 |
+
<p>Coming soon</p>
|
59 |
+
<h3>Remember, chemistry is not just about reactions; it's about connections. Let's build those connections together! 💫</h3>
|
60 |
+
<!-- Links Section -->
|
61 |
+
<div class="links-container">
|
62 |
+
<a href="https://www.knowledgator.com/" class="link-button" target="_blank"><img src="https://assets-global.website-files.com/65902be8ba48a05dfdb73331/6590476fcc8e8f35b2332781_Group%201000002504%20(1).png" alt="Visit our website"></a>
|
63 |
+
<a href="https://www.linkedin.com/company/knowledgator/" class="link-button" target="_blank"><img src="https://www.edigitalagency.com.au/wp-content/uploads/Linkedin-logo-png.png" alt="Follow on LinkedIn"></a>
|
64 |
+
<a href="https://huggingface.co/knowledgator/" class="link-button" target="_blank"><img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-title.png" alt="Hugging Face Profile"></a>
|
65 |
+
</div>
|
66 |
+
</body>
|
67 |
+
</html>
|
modeling/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .model import MT5ForConditionalGeneration
|
2 |
+
from .config import MT5Config
|
modeling/config.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
class MT5Config(PretrainedConfig):
|
4 |
+
r"""
|
5 |
+
This is the configuration class to store the configuration of a [`MT5Model`] or a [`TFMT5Model`]. It is used to
|
6 |
+
instantiate a mT5 model according to the specified arguments, defining the model architecture. Instantiating a
|
7 |
+
configuration with the defaults will yield a similar configuration to that of the mT5
|
8 |
+
[google/mt5-small](https://huggingface.co/google/mt5-small) architecture.
|
9 |
+
|
10 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
11 |
+
documentation from [`PretrainedConfig`] for more information.
|
12 |
+
|
13 |
+
Arguments:
|
14 |
+
vocab_size (`int`, *optional*, defaults to 250112):
|
15 |
+
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
|
16 |
+
`inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
|
17 |
+
d_model (`int`, *optional*, defaults to 512):
|
18 |
+
Size of the encoder layers and the pooler layer.
|
19 |
+
d_kv (`int`, *optional*, defaults to 64):
|
20 |
+
Size of the key, query, value projections per attention head. In the conventional context, it is typically expected that `d_kv` has to be equal to `d_model // num_heads`.
|
21 |
+
But in the architecture of mt5-small, `d_kv` is not equal to `d_model //num_heads`. The `inner_dim` of the projection layer will be defined as `num_heads * d_kv`.
|
22 |
+
d_ff (`int`, *optional*, defaults to 1024):
|
23 |
+
Size of the intermediate feed forward layer in each `T5Block`.
|
24 |
+
num_layers (`int`, *optional*, defaults to 8):
|
25 |
+
Number of hidden layers in the Transformer encoder.
|
26 |
+
num_decoder_layers (`int`, *optional*):
|
27 |
+
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
|
28 |
+
num_heads (`int`, *optional*, defaults to 6):
|
29 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
30 |
+
relative_attention_num_buckets (`int`, *optional*, defaults to 32):
|
31 |
+
The number of buckets to use for each attention layer.
|
32 |
+
relative_attention_max_distance (`int`, *optional*, defaults to 128):
|
33 |
+
The maximum distance of the longer sequences for the bucket separation.
|
34 |
+
dropout_rate (`float`, *optional*, defaults to 0.1):
|
35 |
+
The ratio for all dropout layers.
|
36 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
37 |
+
The dropout ratio for classifier.
|
38 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
39 |
+
The epsilon used by the layer normalization layers.
|
40 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
41 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
42 |
+
testing).
|
43 |
+
feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`):
|
44 |
+
Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`.
|
45 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
46 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
47 |
+
"""
|
48 |
+
|
49 |
+
model_type = "mt5"
|
50 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
encoder_vocab_size=250112,
|
55 |
+
decoder_vocab_size=250112,
|
56 |
+
shared_embedding=False,
|
57 |
+
d_model=256,
|
58 |
+
d_kv=64,
|
59 |
+
d_ff=512,
|
60 |
+
num_layers=4,
|
61 |
+
num_decoder_layers=None,
|
62 |
+
num_heads=3,
|
63 |
+
relative_attention_num_buckets=32,
|
64 |
+
relative_attention_max_distance=128,
|
65 |
+
dropout_rate=0.1,
|
66 |
+
layer_norm_epsilon=1e-6,
|
67 |
+
initializer_factor=1.0,
|
68 |
+
feed_forward_proj="gated-gelu",
|
69 |
+
is_encoder_decoder=True,
|
70 |
+
use_cache=True,
|
71 |
+
tokenizer_class="ChemTokenizers.SMILES_IUPAC_FAST.FastTokenizer",
|
72 |
+
tie_word_embeddings=False,
|
73 |
+
pad_token_id=0,
|
74 |
+
eos_token_id=1,
|
75 |
+
decoder_start_token_id=2,
|
76 |
+
classifier_dropout=0.0,
|
77 |
+
**kwargs,
|
78 |
+
):
|
79 |
+
super().__init__(
|
80 |
+
is_encoder_decoder=is_encoder_decoder,
|
81 |
+
tokenizer_class=tokenizer_class,
|
82 |
+
tie_word_embeddings=tie_word_embeddings,
|
83 |
+
pad_token_id=pad_token_id,
|
84 |
+
eos_token_id=eos_token_id,
|
85 |
+
decoder_start_token_id=decoder_start_token_id,
|
86 |
+
**kwargs,
|
87 |
+
)
|
88 |
+
self.encoder_vocab_size = encoder_vocab_size
|
89 |
+
self.decoder_vocab_size = decoder_vocab_size
|
90 |
+
self.shared_embedding = shared_embedding
|
91 |
+
self.d_model = d_model
|
92 |
+
self.d_kv = d_kv
|
93 |
+
self.d_ff = d_ff
|
94 |
+
self.num_layers = num_layers
|
95 |
+
self.num_decoder_layers = (
|
96 |
+
num_decoder_layers if num_decoder_layers is not None else self.num_layers
|
97 |
+
) # default = symmetry
|
98 |
+
self.num_heads = num_heads
|
99 |
+
self.relative_attention_num_buckets = relative_attention_num_buckets
|
100 |
+
self.relative_attention_max_distance = relative_attention_max_distance
|
101 |
+
self.dropout_rate = dropout_rate
|
102 |
+
self.classifier_dropout = classifier_dropout
|
103 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
104 |
+
self.initializer_factor = initializer_factor
|
105 |
+
self.feed_forward_proj = feed_forward_proj
|
106 |
+
self.use_cache = use_cache
|
107 |
+
|
108 |
+
act_info = self.feed_forward_proj.split("-")
|
109 |
+
self.dense_act_fn = act_info[-1]
|
110 |
+
self.is_gated_act = act_info[0] == "gated"
|
111 |
+
|
112 |
+
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
|
113 |
+
raise ValueError(
|
114 |
+
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
|
115 |
+
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
|
116 |
+
"'gated-gelu' or 'relu'"
|
117 |
+
)
|
118 |
+
|
119 |
+
# for backwards compatibility
|
120 |
+
if feed_forward_proj == "gated-gelu":
|
121 |
+
self.dense_act_fn = "gelu_new"
|
122 |
+
|
123 |
+
@property
|
124 |
+
def hidden_size(self):
|
125 |
+
return self.d_model
|
126 |
+
|
127 |
+
@property
|
128 |
+
def num_attention_heads(self):
|
129 |
+
return self.num_heads
|
130 |
+
|
131 |
+
@property
|
132 |
+
def num_hidden_layers(self):
|
133 |
+
return self.num_layers
|
modeling/docstrings.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PARALLELIZE_DOCSTRING = r"""
|
2 |
+
This is an experimental feature and is a subject to change at a moment's notice.
|
3 |
+
|
4 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
5 |
+
it will evenly distribute blocks across all devices.
|
6 |
+
|
7 |
+
Args:
|
8 |
+
device_map (`Dict[int, list]`, optional, defaults to None):
|
9 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
10 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
11 |
+
have fewer attention modules mapped to it than other devices. For reference, the mt5 models have the
|
12 |
+
following number of attention modules:
|
13 |
+
|
14 |
+
- mt5-small: 6
|
15 |
+
- mt5-base: 12
|
16 |
+
- mt5-large: 24
|
17 |
+
- mt5-xl: 24
|
18 |
+
- mt5-xxl: 24
|
19 |
+
|
20 |
+
Example:
|
21 |
+
|
22 |
+
```python
|
23 |
+
# Here is an example of a device map on a machine with 4 GPUs using mt5-xl, which has a total of 24 attention modules:
|
24 |
+
model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
|
25 |
+
device_map = {
|
26 |
+
0: [0, 1, 2],
|
27 |
+
1: [3, 4, 5, 6, 7, 8, 9],
|
28 |
+
2: [10, 11, 12, 13, 14, 15, 16],
|
29 |
+
3: [17, 18, 19, 20, 21, 22, 23],
|
30 |
+
}
|
31 |
+
model.parallelize(device_map)
|
32 |
+
```
|
33 |
+
"""
|
34 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
35 |
+
Moves the model to cpu from a model parallel state.
|
36 |
+
|
37 |
+
Example:
|
38 |
+
|
39 |
+
```python
|
40 |
+
# On a 4 GPU machine with mt5-xl:
|
41 |
+
model = MT5ForConditionalGeneration.from_pretrained("Mt5-xl")
|
42 |
+
device_map = {
|
43 |
+
0: [0, 1, 2],
|
44 |
+
1: [3, 4, 5, 6, 7, 8, 9],
|
45 |
+
2: [10, 11, 12, 13, 14, 15, 16],
|
46 |
+
3: [17, 18, 19, 20, 21, 22, 23],
|
47 |
+
}
|
48 |
+
model.parallelize(device_map) # Splits the model across several devices
|
49 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
50 |
+
```
|
51 |
+
"""
|
52 |
+
|
53 |
+
__HEAD_MASK_WARNING_MSG = """
|
54 |
+
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
55 |
+
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
56 |
+
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
|
57 |
+
num_heads)`.
|
58 |
+
"""
|
59 |
+
|
60 |
+
MT5_START_DOCSTRING = r"""
|
61 |
+
|
62 |
+
The MT5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
|
63 |
+
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
|
64 |
+
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
|
65 |
+
text-to-text denoising generative setting.
|
66 |
+
|
67 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
68 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
69 |
+
etc.)
|
70 |
+
|
71 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
72 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
73 |
+
and behavior.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
config ([`MT5Config`]): Model configuration class with all the parameters of the model.
|
77 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
78 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
79 |
+
"""
|
80 |
+
|
81 |
+
MT5_INPUTS_DOCSTRING = r"""
|
82 |
+
Args:
|
83 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
84 |
+
Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
|
85 |
+
should be able to pad the inputs on both the right and the left.
|
86 |
+
|
87 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
88 |
+
[`PreTrainedTokenizer.__call__`] for detail.
|
89 |
+
|
90 |
+
[What are input IDs?](../glossary#input-ids)
|
91 |
+
|
92 |
+
To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
|
93 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
94 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
95 |
+
|
96 |
+
- 1 for tokens that are **not masked**,
|
97 |
+
- 0 for tokens that are **masked**.
|
98 |
+
|
99 |
+
[What are attention masks?](../glossary#attention-mask)
|
100 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
101 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
102 |
+
|
103 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
104 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
105 |
+
|
106 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
107 |
+
|
108 |
+
MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
109 |
+
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
110 |
+
|
111 |
+
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
|
112 |
+
Training](./mt5#training).
|
113 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
114 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
115 |
+
be used by default.
|
116 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
117 |
+
Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
|
118 |
+
1]`:
|
119 |
+
|
120 |
+
- 1 indicates the head is **not masked**,
|
121 |
+
- 0 indicates the head is **masked**.
|
122 |
+
|
123 |
+
decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
124 |
+
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
|
125 |
+
1]`:
|
126 |
+
|
127 |
+
- 1 indicates the head is **not masked**,
|
128 |
+
- 0 indicates the head is **masked**.
|
129 |
+
|
130 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
131 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
|
132 |
+
`[0, 1]`:
|
133 |
+
|
134 |
+
- 1 indicates the head is **not masked**,
|
135 |
+
- 0 indicates the head is **masked**.
|
136 |
+
|
137 |
+
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
138 |
+
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
|
139 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
|
140 |
+
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
141 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
142 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
143 |
+
|
144 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
145 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
146 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
147 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
148 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
149 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
150 |
+
model's internal embedding lookup matrix.
|
151 |
+
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
152 |
+
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
|
153 |
+
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
|
154 |
+
input (see `past_key_values`). This is useful if you want more control over how to convert
|
155 |
+
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
156 |
+
|
157 |
+
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
158 |
+
of `inputs_embeds`.
|
159 |
+
|
160 |
+
use_cache (`bool`, *optional*):
|
161 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
162 |
+
`past_key_values`).
|
163 |
+
|
164 |
+
output_attentions (`bool`, *optional*):
|
165 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
166 |
+
tensors for more detail.
|
167 |
+
output_hidden_states (`bool`, *optional*):
|
168 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
169 |
+
more detail.
|
170 |
+
return_dict (`bool`, *optional*):
|
171 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
172 |
+
"""
|
173 |
+
|
174 |
+
MT5_ENCODER_INPUTS_DOCSTRING = r"""
|
175 |
+
Args:
|
176 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
177 |
+
Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
|
178 |
+
should be able to pad the inputs on both the right and the left.
|
179 |
+
|
180 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
181 |
+
[`PreTrainedTokenizer.__call__`] for detail.
|
182 |
+
|
183 |
+
To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
|
184 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
185 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
186 |
+
|
187 |
+
- 1 for tokens that are **not masked**,
|
188 |
+
- 0 for tokens that are **masked**.
|
189 |
+
|
190 |
+
[What are attention masks?](../glossary#attention-mask)
|
191 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
192 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
193 |
+
|
194 |
+
- 1 indicates the head is **not masked**,
|
195 |
+
- 0 indicates the head is **masked**.
|
196 |
+
|
197 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
198 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
199 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
200 |
+
model's internal embedding lookup matrix.
|
201 |
+
output_attentions (`bool`, *optional*):
|
202 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
203 |
+
tensors for more detail.
|
204 |
+
output_hidden_states (`bool`, *optional*):
|
205 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
206 |
+
more detail.
|
207 |
+
return_dict (`bool`, *optional*):
|
208 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
209 |
+
"""
|
210 |
+
|
211 |
+
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
212 |
+
__HEAD_MASK_WARNING_MSG = """
|
213 |
+
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
214 |
+
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
215 |
+
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
|
216 |
+
num_heads)`.
|
217 |
+
"""
|
modeling/model.py
ADDED
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
import warnings
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
from transformers import MT5PreTrainedModel
|
7 |
+
from transformers.models.mt5 import MT5Stack
|
8 |
+
from transformers.modeling_outputs import Seq2SeqModelOutput,Seq2SeqLMOutput, BaseModelOutput
|
9 |
+
from transformers.utils import (
|
10 |
+
add_start_docstrings,
|
11 |
+
add_start_docstrings_to_model_forward,
|
12 |
+
logging,
|
13 |
+
replace_return_docstrings,
|
14 |
+
)
|
15 |
+
|
16 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from .config import MT5Config
|
23 |
+
from .docstrings import (
|
24 |
+
PARALLELIZE_DOCSTRING,
|
25 |
+
DEPARALLELIZE_DOCSTRING,
|
26 |
+
__HEAD_MASK_WARNING_MSG,
|
27 |
+
MT5_START_DOCSTRING,
|
28 |
+
MT5_INPUTS_DOCSTRING,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
_CONFIG_FOR_DOC = "MT5Config"
|
35 |
+
_CHECKPOINT_FOR_DOC = "mt5-small"
|
36 |
+
|
37 |
+
|
38 |
+
class MT5Model(MT5PreTrainedModel):
|
39 |
+
r"""
|
40 |
+
Examples:
|
41 |
+
|
42 |
+
```python
|
43 |
+
>>> from transformers import MT5Model, AutoTokenizer
|
44 |
+
|
45 |
+
>>> model = MT5Model.from_pretrained("google/mt5-small")
|
46 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
|
47 |
+
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
48 |
+
>>> summary = "Weiter Verhandlung in Syrien."
|
49 |
+
>>> inputs = tokenizer(article, return_tensors="pt")
|
50 |
+
>>> labels = tokenizer(text_target=summary, return_tensors="pt")
|
51 |
+
|
52 |
+
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
53 |
+
>>> hidden_states = outputs.last_hidden_state
|
54 |
+
```"""
|
55 |
+
|
56 |
+
model_type = "mt5"
|
57 |
+
config_class = MT5Config
|
58 |
+
_keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
|
59 |
+
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
|
60 |
+
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
61 |
+
|
62 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
|
63 |
+
def __init__(self, config: MT5Config):
|
64 |
+
super().__init__(config)
|
65 |
+
self.encoder_embedding = nn.Embedding(config.encoder_vocab_size, config.d_model)
|
66 |
+
if config.shared_embedding:
|
67 |
+
self.decoder_embedding = self.encoder_embedding
|
68 |
+
else:
|
69 |
+
self.decoder_emebedding = nn.Embedding(config.decoder_vocab_size, config.d_model)
|
70 |
+
|
71 |
+
encoder_config = copy.deepcopy(config)
|
72 |
+
encoder_config.is_decoder = False
|
73 |
+
encoder_config.use_cache = False
|
74 |
+
encoder_config.is_encoder_decoder = False
|
75 |
+
self.encoder = MT5Stack(encoder_config, self.encoder_embedding)
|
76 |
+
|
77 |
+
decoder_config = copy.deepcopy(config)
|
78 |
+
decoder_config.is_decoder = True
|
79 |
+
decoder_config.is_encoder_decoder = False
|
80 |
+
decoder_config.num_layers = config.num_decoder_layers
|
81 |
+
self.decoder = MT5Stack(decoder_config, self.decoder_emebedding)
|
82 |
+
|
83 |
+
# Initialize weights and apply final processing
|
84 |
+
self.post_init()
|
85 |
+
|
86 |
+
# Model parallel
|
87 |
+
self.model_parallel = False
|
88 |
+
self.device_map = None
|
89 |
+
|
90 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.parallelize
|
91 |
+
def parallelize(self, device_map=None):
|
92 |
+
warnings.warn(
|
93 |
+
"`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
|
94 |
+
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
|
95 |
+
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
|
96 |
+
" 0, 'encoder.block.1': 1, ...}",
|
97 |
+
FutureWarning,
|
98 |
+
)
|
99 |
+
self.device_map = (
|
100 |
+
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
|
101 |
+
if device_map is None
|
102 |
+
else device_map
|
103 |
+
)
|
104 |
+
assert_device_map(self.device_map, len(self.encoder.block))
|
105 |
+
self.encoder.parallelize(self.device_map)
|
106 |
+
self.decoder.parallelize(self.device_map)
|
107 |
+
self.model_parallel = True
|
108 |
+
|
109 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
110 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.deparallelize
|
111 |
+
def deparallelize(self):
|
112 |
+
warnings.warn(
|
113 |
+
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
|
114 |
+
FutureWarning,
|
115 |
+
)
|
116 |
+
self.encoder.deparallelize()
|
117 |
+
self.decoder.deparallelize()
|
118 |
+
self.encoder = self.encoder.to("cpu")
|
119 |
+
self.decoder = self.decoder.to("cpu")
|
120 |
+
self.model_parallel = False
|
121 |
+
self.device_map = None
|
122 |
+
torch.cuda.empty_cache()
|
123 |
+
|
124 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
|
125 |
+
def get_input_embeddings(self):
|
126 |
+
return self.encoder_embedding
|
127 |
+
|
128 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
|
129 |
+
def set_input_embeddings(self, new_embeddings):
|
130 |
+
self.encoder_embedding = new_embeddings
|
131 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
132 |
+
self.decoder.set_input_embeddings(new_embeddings)
|
133 |
+
|
134 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
|
135 |
+
def get_encoder(self):
|
136 |
+
return self.encoder
|
137 |
+
|
138 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.get_decoder
|
139 |
+
def get_decoder(self):
|
140 |
+
return self.decoder
|
141 |
+
|
142 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads
|
143 |
+
def _prune_heads(self, heads_to_prune):
|
144 |
+
"""
|
145 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
146 |
+
class PreTrainedModel
|
147 |
+
"""
|
148 |
+
for layer, heads in heads_to_prune.items():
|
149 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
150 |
+
|
151 |
+
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
152 |
+
# Copied from transformers.models.t5.modeling_t5.T5Model.forward with T5->MT5, t5->mt5
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
input_ids: Optional[torch.LongTensor] = None,
|
156 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
157 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
158 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
159 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
160 |
+
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
161 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
162 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
163 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
164 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
165 |
+
decoder_inputs_embeds: Optional[torch.Tensor] = None,
|
166 |
+
use_cache: Optional[bool] = None,
|
167 |
+
output_attentions: Optional[bool] = None,
|
168 |
+
output_hidden_states: Optional[bool] = None,
|
169 |
+
return_dict: Optional[bool] = None,
|
170 |
+
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
171 |
+
r"""
|
172 |
+
Returns:
|
173 |
+
|
174 |
+
Example:
|
175 |
+
|
176 |
+
```python
|
177 |
+
>>> from transformers import AutoTokenizer, MT5Model
|
178 |
+
|
179 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mt5-small")
|
180 |
+
>>> model = MT5Model.from_pretrained("mt5-small")
|
181 |
+
|
182 |
+
>>> input_ids = tokenizer(
|
183 |
+
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
|
184 |
+
... ).input_ids # Batch size 1
|
185 |
+
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
|
186 |
+
|
187 |
+
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MT5Model.
|
188 |
+
>>> # This is not needed for torch's MT5ForConditionalGeneration as it does this internally using labels arg.
|
189 |
+
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
190 |
+
|
191 |
+
>>> # forward pass
|
192 |
+
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
193 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
194 |
+
```"""
|
195 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
196 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
197 |
+
|
198 |
+
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
199 |
+
if head_mask is not None and decoder_head_mask is None:
|
200 |
+
if self.config.num_layers == self.config.num_decoder_layers:
|
201 |
+
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
202 |
+
decoder_head_mask = head_mask
|
203 |
+
|
204 |
+
# Encode if needed (training, first prediction pass)
|
205 |
+
if encoder_outputs is None:
|
206 |
+
encoder_outputs = self.encoder(
|
207 |
+
input_ids=input_ids,
|
208 |
+
attention_mask=attention_mask,
|
209 |
+
inputs_embeds=inputs_embeds,
|
210 |
+
head_mask=head_mask,
|
211 |
+
output_attentions=output_attentions,
|
212 |
+
output_hidden_states=output_hidden_states,
|
213 |
+
return_dict=return_dict,
|
214 |
+
)
|
215 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
216 |
+
encoder_outputs = BaseModelOutput(
|
217 |
+
last_hidden_state=encoder_outputs[0],
|
218 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
219 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
220 |
+
)
|
221 |
+
|
222 |
+
hidden_states = encoder_outputs[0]
|
223 |
+
|
224 |
+
# Set device for model parallelism
|
225 |
+
if self.model_parallel:
|
226 |
+
torch.cuda.set_device(self.decoder.first_device)
|
227 |
+
hidden_states = hidden_states.to(self.decoder.first_device)
|
228 |
+
if decoder_input_ids is not None:
|
229 |
+
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
|
230 |
+
if attention_mask is not None:
|
231 |
+
attention_mask = attention_mask.to(self.decoder.first_device)
|
232 |
+
if decoder_attention_mask is not None:
|
233 |
+
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
|
234 |
+
|
235 |
+
# Decode
|
236 |
+
decoder_outputs = self.decoder(
|
237 |
+
input_ids=decoder_input_ids,
|
238 |
+
attention_mask=decoder_attention_mask,
|
239 |
+
inputs_embeds=decoder_inputs_embeds,
|
240 |
+
past_key_values=past_key_values,
|
241 |
+
encoder_hidden_states=hidden_states,
|
242 |
+
encoder_attention_mask=attention_mask,
|
243 |
+
head_mask=decoder_head_mask,
|
244 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
245 |
+
use_cache=use_cache,
|
246 |
+
output_attentions=output_attentions,
|
247 |
+
output_hidden_states=output_hidden_states,
|
248 |
+
return_dict=return_dict,
|
249 |
+
)
|
250 |
+
|
251 |
+
if not return_dict:
|
252 |
+
return decoder_outputs + encoder_outputs
|
253 |
+
|
254 |
+
return Seq2SeqModelOutput(
|
255 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
256 |
+
past_key_values=decoder_outputs.past_key_values,
|
257 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
258 |
+
decoder_attentions=decoder_outputs.attentions,
|
259 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
260 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
261 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
262 |
+
encoder_attentions=encoder_outputs.attentions,
|
263 |
+
)
|
264 |
+
|
265 |
+
|
266 |
+
@add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING)
|
267 |
+
class MT5ForConditionalGeneration(MT5PreTrainedModel):
|
268 |
+
r"""
|
269 |
+
Examples:
|
270 |
+
|
271 |
+
```python
|
272 |
+
>>> from transformers import MT5ForConditionalGeneration, AutoTokenizer
|
273 |
+
|
274 |
+
>>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
275 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
|
276 |
+
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
277 |
+
>>> summary = "Weiter Verhandlung in Syrien."
|
278 |
+
>>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
|
279 |
+
|
280 |
+
>>> outputs = model(**inputs)
|
281 |
+
>>> loss = outputs.loss
|
282 |
+
```"""
|
283 |
+
|
284 |
+
model_type = "mt5"
|
285 |
+
config_class = MT5Config
|
286 |
+
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
|
287 |
+
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
288 |
+
|
289 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
|
290 |
+
def __init__(self, config: MT5Config):
|
291 |
+
super().__init__(config)
|
292 |
+
self.model_dim = config.d_model
|
293 |
+
|
294 |
+
self.encoder_embedding = nn.Embedding(config.encoder_vocab_size, config.d_model)
|
295 |
+
if config.shared_embedding:
|
296 |
+
self.decoder_embedding = self.encoder_embedding
|
297 |
+
else:
|
298 |
+
self.decoder_emebedding = nn.Embedding(config.decoder_vocab_size, config.d_model)
|
299 |
+
|
300 |
+
encoder_config = copy.deepcopy(config)
|
301 |
+
encoder_config.is_decoder = False
|
302 |
+
encoder_config.use_cache = False
|
303 |
+
encoder_config.is_encoder_decoder = False
|
304 |
+
self.encoder = MT5Stack(encoder_config, self.encoder_embedding)
|
305 |
+
|
306 |
+
decoder_config = copy.deepcopy(config)
|
307 |
+
decoder_config.is_decoder = True
|
308 |
+
decoder_config.is_encoder_decoder = False
|
309 |
+
decoder_config.num_layers = config.num_decoder_layers
|
310 |
+
self.decoder = MT5Stack(decoder_config, self.decoder_emebedding)
|
311 |
+
|
312 |
+
self.lm_head = nn.Linear(config.d_model, config.decoder_vocab_size, bias=False)
|
313 |
+
|
314 |
+
# Initialize weights and apply final processing
|
315 |
+
self.post_init()
|
316 |
+
|
317 |
+
# Model parallel
|
318 |
+
self.model_parallel = False
|
319 |
+
self.device_map = None
|
320 |
+
|
321 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
322 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize
|
323 |
+
def parallelize(self, device_map=None):
|
324 |
+
warnings.warn(
|
325 |
+
"`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
|
326 |
+
" should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
|
327 |
+
" provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
|
328 |
+
" {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
|
329 |
+
FutureWarning,
|
330 |
+
)
|
331 |
+
self.device_map = (
|
332 |
+
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
|
333 |
+
if device_map is None
|
334 |
+
else device_map
|
335 |
+
)
|
336 |
+
assert_device_map(self.device_map, len(self.encoder.block))
|
337 |
+
self.encoder.parallelize(self.device_map)
|
338 |
+
self.decoder.parallelize(self.device_map)
|
339 |
+
self.lm_head = self.lm_head.to(self.decoder.first_device)
|
340 |
+
self.model_parallel = True
|
341 |
+
|
342 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
343 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.deparallelize
|
344 |
+
def deparallelize(self):
|
345 |
+
warnings.warn(
|
346 |
+
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
|
347 |
+
FutureWarning,
|
348 |
+
)
|
349 |
+
self.encoder.deparallelize()
|
350 |
+
self.decoder.deparallelize()
|
351 |
+
self.encoder = self.encoder.to("cpu")
|
352 |
+
self.decoder = self.decoder.to("cpu")
|
353 |
+
self.lm_head = self.lm_head.to("cpu")
|
354 |
+
self.model_parallel = False
|
355 |
+
self.device_map = None
|
356 |
+
torch.cuda.empty_cache()
|
357 |
+
|
358 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_input_embeddings
|
359 |
+
def get_input_embeddings(self):
|
360 |
+
return self.encoder_embedding
|
361 |
+
|
362 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_input_embeddings
|
363 |
+
def set_input_embeddings(self, new_embeddings):
|
364 |
+
self.encoder_embedding = new_embeddings
|
365 |
+
self.encoder.set_input_embeddings(new_embeddings)
|
366 |
+
self.decoder.set_input_embeddings(new_embeddings)
|
367 |
+
|
368 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
|
369 |
+
def set_output_embeddings(self, new_embeddings):
|
370 |
+
self.lm_head = new_embeddings
|
371 |
+
|
372 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_output_embeddings
|
373 |
+
def get_output_embeddings(self):
|
374 |
+
return self.lm_head
|
375 |
+
|
376 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder
|
377 |
+
def get_encoder(self):
|
378 |
+
return self.encoder
|
379 |
+
|
380 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_decoder
|
381 |
+
def get_decoder(self):
|
382 |
+
return self.decoder
|
383 |
+
|
384 |
+
@add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
|
385 |
+
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
386 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with T5->MT5, t5->mt5
|
387 |
+
def forward(
|
388 |
+
self,
|
389 |
+
input_ids: Optional[torch.LongTensor] = None,
|
390 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
391 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
392 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
393 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
394 |
+
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
395 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
396 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
397 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
398 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
399 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
400 |
+
labels: Optional[torch.LongTensor] = None,
|
401 |
+
use_cache: Optional[bool] = None,
|
402 |
+
output_attentions: Optional[bool] = None,
|
403 |
+
output_hidden_states: Optional[bool] = None,
|
404 |
+
return_dict: Optional[bool] = None,
|
405 |
+
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
406 |
+
r"""
|
407 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
408 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
|
409 |
+
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
|
410 |
+
labels in `[0, ..., config.vocab_size]`
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
|
414 |
+
Examples:
|
415 |
+
|
416 |
+
```python
|
417 |
+
>>> from transformers import AutoTokenizer, MT5ForConditionalGeneration
|
418 |
+
|
419 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mt5-small")
|
420 |
+
>>> model = MT5ForConditionalGeneration.from_pretrained("mt5-small")
|
421 |
+
|
422 |
+
>>> # training
|
423 |
+
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
424 |
+
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
|
425 |
+
>>> outputs = model(input_ids=input_ids, labels=labels)
|
426 |
+
>>> loss = outputs.loss
|
427 |
+
>>> logits = outputs.logits
|
428 |
+
|
429 |
+
>>> # inference
|
430 |
+
>>> input_ids = tokenizer(
|
431 |
+
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
|
432 |
+
... ).input_ids # Batch size 1
|
433 |
+
>>> outputs = model.generate(input_ids)
|
434 |
+
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
435 |
+
>>> # studies have shown that owning a dog is good for you.
|
436 |
+
```"""
|
437 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
438 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
439 |
+
|
440 |
+
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
441 |
+
if head_mask is not None and decoder_head_mask is None:
|
442 |
+
if self.config.num_layers == self.config.num_decoder_layers:
|
443 |
+
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
444 |
+
decoder_head_mask = head_mask
|
445 |
+
|
446 |
+
# Encode if needed (training, first prediction pass)
|
447 |
+
if encoder_outputs is None:
|
448 |
+
# Convert encoder inputs in embeddings if needed
|
449 |
+
encoder_outputs = self.encoder(
|
450 |
+
input_ids=input_ids,
|
451 |
+
attention_mask=attention_mask,
|
452 |
+
inputs_embeds=inputs_embeds,
|
453 |
+
head_mask=head_mask,
|
454 |
+
output_attentions=output_attentions,
|
455 |
+
output_hidden_states=output_hidden_states,
|
456 |
+
return_dict=return_dict,
|
457 |
+
)
|
458 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
459 |
+
encoder_outputs = BaseModelOutput(
|
460 |
+
last_hidden_state=encoder_outputs[0],
|
461 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
462 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
463 |
+
)
|
464 |
+
|
465 |
+
hidden_states = encoder_outputs[0]
|
466 |
+
|
467 |
+
if self.model_parallel:
|
468 |
+
torch.cuda.set_device(self.decoder.first_device)
|
469 |
+
|
470 |
+
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
471 |
+
# get decoder inputs from shifting lm labels to the right
|
472 |
+
decoder_input_ids = self._shift_right(labels)
|
473 |
+
|
474 |
+
# Set device for model parallelism
|
475 |
+
if self.model_parallel:
|
476 |
+
torch.cuda.set_device(self.decoder.first_device)
|
477 |
+
hidden_states = hidden_states.to(self.decoder.first_device)
|
478 |
+
if decoder_input_ids is not None:
|
479 |
+
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
|
480 |
+
if attention_mask is not None:
|
481 |
+
attention_mask = attention_mask.to(self.decoder.first_device)
|
482 |
+
if decoder_attention_mask is not None:
|
483 |
+
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
|
484 |
+
|
485 |
+
# Decode
|
486 |
+
decoder_outputs = self.decoder(
|
487 |
+
input_ids=decoder_input_ids,
|
488 |
+
attention_mask=decoder_attention_mask,
|
489 |
+
inputs_embeds=decoder_inputs_embeds,
|
490 |
+
past_key_values=past_key_values,
|
491 |
+
encoder_hidden_states=hidden_states,
|
492 |
+
encoder_attention_mask=attention_mask,
|
493 |
+
head_mask=decoder_head_mask,
|
494 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
495 |
+
use_cache=use_cache,
|
496 |
+
output_attentions=output_attentions,
|
497 |
+
output_hidden_states=output_hidden_states,
|
498 |
+
return_dict=return_dict,
|
499 |
+
)
|
500 |
+
|
501 |
+
sequence_output = decoder_outputs[0]
|
502 |
+
|
503 |
+
# Set device for model parallelism
|
504 |
+
if self.model_parallel:
|
505 |
+
torch.cuda.set_device(self.encoder.first_device)
|
506 |
+
self.lm_head = self.lm_head.to(self.encoder.first_device)
|
507 |
+
sequence_output = sequence_output.to(self.lm_head.weight.device)
|
508 |
+
|
509 |
+
if self.config.tie_word_embeddings:
|
510 |
+
# Rescale output before projecting on vocab
|
511 |
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
512 |
+
sequence_output = sequence_output * (self.model_dim**-0.5)
|
513 |
+
|
514 |
+
lm_logits = self.lm_head(sequence_output)
|
515 |
+
|
516 |
+
loss = None
|
517 |
+
if labels is not None:
|
518 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
519 |
+
# move labels to correct device to enable PP
|
520 |
+
labels = labels.to(lm_logits.device)
|
521 |
+
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
522 |
+
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
523 |
+
|
524 |
+
if not return_dict:
|
525 |
+
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
526 |
+
return ((loss,) + output) if loss is not None else output
|
527 |
+
|
528 |
+
return Seq2SeqLMOutput(
|
529 |
+
loss=loss,
|
530 |
+
logits=lm_logits,
|
531 |
+
past_key_values=decoder_outputs.past_key_values,
|
532 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
533 |
+
decoder_attentions=decoder_outputs.attentions,
|
534 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
535 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
536 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
537 |
+
encoder_attentions=encoder_outputs.attentions,
|
538 |
+
)
|
539 |
+
|
540 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
|
541 |
+
def prepare_inputs_for_generation(
|
542 |
+
self,
|
543 |
+
input_ids,
|
544 |
+
past_key_values=None,
|
545 |
+
attention_mask=None,
|
546 |
+
head_mask=None,
|
547 |
+
decoder_head_mask=None,
|
548 |
+
decoder_attention_mask=None,
|
549 |
+
cross_attn_head_mask=None,
|
550 |
+
use_cache=None,
|
551 |
+
encoder_outputs=None,
|
552 |
+
**kwargs,
|
553 |
+
):
|
554 |
+
# cut decoder_input_ids if past_key_values is used
|
555 |
+
if past_key_values is not None:
|
556 |
+
past_length = past_key_values[0][0].shape[2]
|
557 |
+
|
558 |
+
# Some generation methods already pass only the last input ID
|
559 |
+
if input_ids.shape[1] > past_length:
|
560 |
+
remove_prefix_length = past_length
|
561 |
+
else:
|
562 |
+
# Default to old behavior: keep only final ID
|
563 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
564 |
+
|
565 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
566 |
+
|
567 |
+
return {
|
568 |
+
"decoder_input_ids": input_ids,
|
569 |
+
"past_key_values": past_key_values,
|
570 |
+
"encoder_outputs": encoder_outputs,
|
571 |
+
"attention_mask": attention_mask,
|
572 |
+
"head_mask": head_mask,
|
573 |
+
"decoder_head_mask": decoder_head_mask,
|
574 |
+
"decoder_attention_mask": decoder_attention_mask,
|
575 |
+
"cross_attn_head_mask": cross_attn_head_mask,
|
576 |
+
"use_cache": use_cache,
|
577 |
+
}
|
578 |
+
|
579 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
|
580 |
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
581 |
+
return self._shift_right(labels)
|
582 |
+
|
583 |
+
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache
|
584 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
585 |
+
# if decoder past is not included in output
|
586 |
+
# speedy decoding is disabled and no need to reorder
|
587 |
+
if past_key_values is None:
|
588 |
+
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
|
589 |
+
return past_key_values
|
590 |
+
|
591 |
+
reordered_decoder_past = ()
|
592 |
+
for layer_past_states in past_key_values:
|
593 |
+
# get the correct batch idx from layer past batch dim
|
594 |
+
# batch dim of `past` is at 2nd position
|
595 |
+
reordered_layer_past_states = ()
|
596 |
+
for layer_past_state in layer_past_states:
|
597 |
+
# need to set correct `past` for each of the four key / value states
|
598 |
+
reordered_layer_past_states = reordered_layer_past_states + (
|
599 |
+
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
|
600 |
+
)
|
601 |
+
|
602 |
+
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
|
603 |
+
raise ValueError(
|
604 |
+
f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
|
605 |
+
)
|
606 |
+
if len(reordered_layer_past_states) != len(layer_past_states):
|
607 |
+
raise ValueError(
|
608 |
+
f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
|
609 |
+
)
|
610 |
+
|
611 |
+
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
612 |
+
return reordered_decoder_past
|
models/IUPAC2SMILES/config.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"MT5ForConditionalGeneration"
|
4 |
+
],
|
5 |
+
"classifier_dropout": 0.0,
|
6 |
+
"d_ff": 512,
|
7 |
+
"d_kv": 64,
|
8 |
+
"d_model": 256,
|
9 |
+
"decoder_start_token_id": 2,
|
10 |
+
"decoder_vocab_size": 137,
|
11 |
+
"dense_act_fn": "gelu_new",
|
12 |
+
"dropout_rate": 0.1,
|
13 |
+
"encoder_vocab_size": 822,
|
14 |
+
"eos_token_id": 1,
|
15 |
+
"feed_forward_proj": "gated-gelu",
|
16 |
+
"initializer_factor": 1.0,
|
17 |
+
"is_encoder_decoder": true,
|
18 |
+
"is_gated_act": true,
|
19 |
+
"layer_norm_epsilon": 1e-06,
|
20 |
+
"model_type": "mt5",
|
21 |
+
"num_decoder_layers": 4,
|
22 |
+
"num_heads": 3,
|
23 |
+
"num_layers": 4,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"relative_attention_max_distance": 128,
|
26 |
+
"relative_attention_num_buckets": 32,
|
27 |
+
"shared_embedding": false,
|
28 |
+
"tie_word_embeddings": false,
|
29 |
+
"tokenizer_class": "T5Tokenizer",
|
30 |
+
"torch_dtype": "float32",
|
31 |
+
"transformers_version": "4.37.1",
|
32 |
+
"use_cache": true
|
33 |
+
}
|
models/IUPAC2SMILES/generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"decoder_start_token_id": 2,
|
4 |
+
"eos_token_id": 1,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.37.1"
|
7 |
+
}
|
models/IUPAC2SMILES/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1f38994ec986388a2f099652139d6a05b5981fb57bdf62361d4614f84ca07ed
|
3 |
+
size 23177168
|
models/SMILES2IUPAC/config.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"MT5ForConditionalGeneration"
|
4 |
+
],
|
5 |
+
"classifier_dropout": 0.0,
|
6 |
+
"d_ff": 512,
|
7 |
+
"d_kv": 64,
|
8 |
+
"d_model": 256,
|
9 |
+
"decoder_start_token_id": 2,
|
10 |
+
"decoder_vocab_size": 822,
|
11 |
+
"dense_act_fn": "gelu_new",
|
12 |
+
"dropout_rate": 0.1,
|
13 |
+
"encoder_vocab_size": 137,
|
14 |
+
"eos_token_id": 1,
|
15 |
+
"feed_forward_proj": "gated-gelu",
|
16 |
+
"initializer_factor": 1.0,
|
17 |
+
"is_encoder_decoder": true,
|
18 |
+
"is_gated_act": true,
|
19 |
+
"layer_norm_epsilon": 1e-06,
|
20 |
+
"model_type": "mt5",
|
21 |
+
"num_decoder_layers": 4,
|
22 |
+
"num_heads": 3,
|
23 |
+
"num_layers": 4,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"relative_attention_max_distance": 128,
|
26 |
+
"relative_attention_num_buckets": 32,
|
27 |
+
"shared_embedding": false,
|
28 |
+
"tie_word_embeddings": false,
|
29 |
+
"tokenizer_class": "T5Tokenizer",
|
30 |
+
"torch_dtype": "float32",
|
31 |
+
"transformers_version": "4.37.1",
|
32 |
+
"use_cache": true
|
33 |
+
}
|
models/SMILES2IUPAC/generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"decoder_start_token_id": 2,
|
4 |
+
"eos_token_id": 1,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.37.1"
|
7 |
+
}
|
models/SMILES2IUPAC/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4307a50d6b192a06bb81552d7cd6bcf6ac7ea6bb21d72ca4755e28d7d28655d2
|
3 |
+
size 23878608
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
torch
|
3 |
+
rdkit
|
test.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from gradio_client import Client
|
2 |
+
|
3 |
+
client = Client("https://knowledgator-chemicalconverters.hf.space/--replicas/ucig0/")
|
4 |
+
result = client.predict(
|
5 |
+
"CCO", # str in 'Enter your chemical name' Textbox component
|
6 |
+
"SMILES2IUPAC", # Literal['SMILES2IUPAC', 'IUPAC2SMILES', 'IUPAC style prediction'] in 'Choose method to convert chemical names' Radio component
|
7 |
+
"BASE", # Literal['BASE', 'SYSTEMATIC', 'TRADITIONAL'] in 'If SMILES to IUPAC, choose desired IUPAC style' Radio component
|
8 |
+
True, # bool in 'Validate with molecular similarity' Checkbox component
|
9 |
+
True, # bool in 'Plot molecule' Checkbox component
|
10 |
+
api_name="/predict"
|
11 |
+
)
|
12 |
+
print(result)
|
utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .main_model import ChemicalConverter
|
2 |
+
from .rdkit_utils import validate_smiles2iupac, plot_mol
|
utils/main_model.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modeling import MT5ForConditionalGeneration
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class ChemicalConverter:
|
7 |
+
def __init__(self, mode: str):
|
8 |
+
self.mode = mode
|
9 |
+
model_directory = os.path.abspath("models")
|
10 |
+
model_path = os.path.join(model_directory, mode)
|
11 |
+
if not os.path.exists(model_path):
|
12 |
+
raise ValueError(f"Model path does not exist: {model_path}")
|
13 |
+
self.model = MT5ForConditionalGeneration.from_pretrained(model_path)
|
14 |
+
self.smiles_tokenizer = AutoTokenizer.from_pretrained("BioMike/smiles")
|
15 |
+
self.iupac_tokenizer = AutoTokenizer.from_pretrained("BioMike/iupac")
|
16 |
+
self.smiles_max_len = 128
|
17 |
+
self.iupac_max_len = 156
|
18 |
+
|
19 |
+
def convert(self, input):
|
20 |
+
if self.mode == "SMILES2IUPAC":
|
21 |
+
tokenizer = self.smiles_tokenizer
|
22 |
+
reverse_tokenizer = self.iupac_tokenizer
|
23 |
+
max_length = self.smiles_max_len
|
24 |
+
else:
|
25 |
+
tokenizer = self.iupac_tokenizer
|
26 |
+
reverse_tokenizer = self.smiles_tokenizer
|
27 |
+
max_length = self.iupac_max_len
|
28 |
+
|
29 |
+
encoding = tokenizer(input,
|
30 |
+
return_tensors='pt',
|
31 |
+
padding="max_length",
|
32 |
+
truncation=True,
|
33 |
+
max_length=max_length)
|
34 |
+
# Move the input tensor to GPU
|
35 |
+
encoding = {key: value.to(self.model.device) for key, value in encoding.items()}
|
36 |
+
|
37 |
+
# Generate names
|
38 |
+
output = self.model.generate(input_ids=encoding['input_ids'],
|
39 |
+
attention_mask=encoding['attention_mask'],
|
40 |
+
max_new_tokens=156,
|
41 |
+
num_beams=1,
|
42 |
+
num_return_sequences=1)
|
43 |
+
|
44 |
+
# Decode names
|
45 |
+
output = [reverse_tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
|
46 |
+
|
47 |
+
return output[0]
|
utils/rdkit_utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rdkit import DataStructs, Chem
|
2 |
+
from rdkit.Chem import AllChem
|
3 |
+
from rdkit.Chem import Draw
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
+
from .main_model import ChemicalConverter
|
7 |
+
|
8 |
+
def validate_smiles2iupac(input_smiles, predicted_iupac):
|
9 |
+
converter = ChemicalConverter(mode="IUPAC2SMILES")
|
10 |
+
predicted_smiles = converter.convert(predicted_iupac)
|
11 |
+
|
12 |
+
ms = [Chem.MolFromSmiles(input_smiles), Chem.MolFromSmiles(predicted_smiles[6:])]
|
13 |
+
|
14 |
+
if None in ms:
|
15 |
+
return None
|
16 |
+
|
17 |
+
fpgen = AllChem.GetRDKitFPGenerator()
|
18 |
+
fps = [fpgen.GetFingerprint(x) for x in ms]
|
19 |
+
|
20 |
+
return DataStructs.TanimotoSimilarity(fps[0], fps[1])
|
21 |
+
|
22 |
+
def plot_mol(smiles):
|
23 |
+
# Convert the SMILES string to an RDKit molecule object
|
24 |
+
mol = Chem.MolFromSmiles(smiles)
|
25 |
+
|
26 |
+
# Use RDKit to draw the molecule to an image, with original intended size
|
27 |
+
img = Draw.MolToImage(mol, size=(185, 185))
|
28 |
+
|
29 |
+
# Create a new, blank image with the desired final size (800x190 pixels) with a white background
|
30 |
+
final_img = Image.new('RGB', (890, 185), 'white')
|
31 |
+
|
32 |
+
# Calculate the position to paste the original image onto the blank image to keep it centered
|
33 |
+
left = (890 - 185) // 2
|
34 |
+
top = (185 - 185) // 2 # This will be zero in this case but included for clarity
|
35 |
+
|
36 |
+
# Paste the original image onto the blank image
|
37 |
+
final_img.paste(img, (left, top))
|
38 |
+
|
39 |
+
return final_img
|