# Snapshot metadata: commit aad6030, file size 1,057 bytes.
from transformers import pipeline, Text2TextGenerationPipeline
# Name (or path) used for both the model weights and the tokenizer.
# NOTE(review): presumably a local checkpoint directory or a Hub repo id — confirm.
model_name = "test-BB2"

# Base text2text-generation pipeline, built once when this module is loaded.
model = pipeline(task="text2text-generation", model=model_name, tokenizer=model_name)
class MyText2TextGenerationPipeline(Text2TextGenerationPipeline):
    """Text2TextGenerationPipeline that falls back to a default input text.

    If a call supplies no input, the fallback is used as the text sent to the
    model. "No input" means no positional argument and no ``inputs`` keyword,
    or an input that is ``None`` or a blank string. Any other input is passed
    through unchanged.

    Args:
        default_text: Fallback input text. When ``None`` (the default), calls
            are forwarded to the base pipeline untouched.
    """

    def __init__(self, *args, default_text=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_text = default_text

    def __call__(self, *args, **kwargs):
        """Run the pipeline, substituting ``default_text`` for missing input."""
        if self.default_text is not None:
            if args:
                text, rest = args[0], args[1:]
            else:
                # Normalise a keyword ``inputs`` to positional form: the base
                # ``__call__`` inspects ``args[0]`` after running.
                text, rest = kwargs.pop("inputs", None), ()
            if text is None or (isinstance(text, str) and not text.strip()):
                # Replace the real input; the base pipeline has no ``prompt``
                # parameter, so an extra kwarg would not be used as input.
                text = self.default_text
            args = (text, *rest)
        return super().__call__(*args, **kwargs)
def generate_text(input_text, default_text="Enter your input text here", max_length=100):
    """Generate text for *input_text* using the module-level ``model`` pipeline.

    Args:
        input_text: Text fed to the model. If ``None`` or blank,
            *default_text* is used instead.
        default_text: Fallback input text used when *input_text* is missing.
        max_length: Maximum generation length forwarded to the pipeline call.

    Returns:
        The pipeline's output for *input_text*. For this task, transformers
        documents this as a list of dicts keyed by ``"generated_text"``.
    """
    generator = MyText2TextGenerationPipeline(
        # ``model`` is a Pipeline object; the constructor needs the underlying
        # model, not the pipeline wrapper.
        model=model.model,
        tokenizer=model.tokenizer,
        default_text=default_text,
    )
    return generator(input_text, max_length=max_length)
# Example usage: runs only when this file is executed as a script, so that
# importing the module does not trigger generation and printing.
if __name__ == "__main__":
    input_text = "The quick brown fox jumps over the lazy dog."
    generated_text = generate_text(input_text, default_text="Enter a new input text here")
    print(generated_text)