|
from transformers import pipeline, Text2TextGenerationPipeline |
|
|
|
# Model identifier handed to transformers' pipeline() for both the weights and
# the tokenizer (presumably a Hub id or local checkpoint directory — confirm).
model_name = "test-BB2"

# NOTE: despite its name, ``model`` is the object returned by pipeline(), i.e.
# a text2text-generation *pipeline*, not a bare model. It is created at import
# time because this is a module-level statement.
model = pipeline(
    task="text2text-generation",
    model=model_name,
    tokenizer=model_name,
)
|
|
|
class MyText2TextGenerationPipeline(Text2TextGenerationPipeline):
    """Text2text-generation pipeline that can fall back to a default input text.

    Args:
        *args, **kwargs: Forwarded unchanged to ``Text2TextGenerationPipeline``.
        default_text: Text to generate from when a call supplies no input (no
            positional input and no ``inputs=`` keyword), or supplies ``None``
            or a blank string. ``None`` disables the fallback entirely.
    """

    def __init__(self, *args, default_text=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_text = default_text

    def __call__(self, *args, **kwargs):
        """Run the pipeline, substituting ``default_text`` for a missing input.

        Returns whatever ``Text2TextGenerationPipeline.__call__`` returns.
        """
        # Fallback disabled: behave exactly like the parent pipeline.
        if self.default_text is None:
            return super().__call__(*args, **kwargs)

        if args:
            inputs, rest = args[0], args[1:]
        else:
            inputs, rest = kwargs.pop("inputs", None), ()

        # Use the default as the *input text* rather than as a ``prompt=``
        # kwarg. Text2TextGenerationPipeline has no ``prompt`` parameter, so
        # such a kwarg is forwarded toward generation instead of being used as
        # input.
        if inputs is None or (isinstance(inputs, str) and not inputs.strip()):
            inputs = self.default_text

        # Pass the input positionally. NOTE(review): the parent __call__ in the
        # transformers versions checked inspects ``args[0]``, so keyword-only
        # input would break it — confirm against the pinned version.
        return super().__call__(inputs, *rest, **kwargs)
|
|
|
def generate_text(
    input_text,
    default_text="Enter your input text here",
    max_length=100,
):
    """Generate text for ``input_text`` using the module-level pipeline's model.

    This builds a ``MyText2TextGenerationPipeline`` around the model and
    tokenizer already held by the module-level ``model`` pipeline, then calls
    it once. The existing model object is reused, so no weights are reloaded.

    Args:
        input_text: Text passed to the pipeline as its input.
        default_text: Forwarded as ``default_text`` to
            ``MyText2TextGenerationPipeline``.
        max_length: Length limit forwarded to the pipeline call (default 100).

    Returns:
        The pipeline's raw output for ``input_text``.
    """
    # ``model`` is a pipeline object (see its assignment from pipeline()), but
    # a Pipeline constructor needs the underlying model, which pipelines expose
    # as ``.model``.
    generator = MyText2TextGenerationPipeline(
        model=model.model,
        tokenizer=model.tokenizer,
        default_text=default_text,
    )
    return generator(input_text, max_length=max_length)
|
|
|
|
|
# Demo: run generate_text once on a fixed sample sentence and print the raw
# pipeline output.
input_text = "The quick brown fox jumps over the lazy dog."
generated_text = generate_text(
    input_text,
    default_text="Enter a new input text here",
)
print(generated_text)
|
|