Text2Text Generation
Transformers
PyTorch
Safetensors
English
encoder-decoder
medical
Inference Endpoints
Waleed-bin-Qamar commited on
Commit
aad6030
1 Parent(s): a6760d1

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +31 -0
inference.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Set up a shared text2text-generation pipeline for this module."""
from transformers import pipeline, Text2TextGenerationPipeline

# Checkpoint identifier; the same name supplies both the weights and tokenizer.
model_name = "test-BB2"

# Module-level pipeline, constructed once at import time and reused below.
model = pipeline("text2text-generation", model=model_name, tokenizer=model_name)
9
+
10
class MyText2TextGenerationPipeline(Text2TextGenerationPipeline):
    """Text2text-generation pipeline with a fallback default input.

    When the pipeline is called with no positional text inputs,
    ``default_text`` is used as the input sequence instead.
    """

    def __init__(self, *args, default_text=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Fallback input used when the caller provides no text.
        self.default_text = default_text

    def __call__(self, *args, **kwargs):
        # BUG FIX: the original injected kwargs["prompt"], but
        # Text2TextGenerationPipeline.__call__ accepts no "prompt" parameter;
        # the kwarg would be forwarded to generate() and raise a TypeError,
        # and the default text never reached the model. Supply the default
        # as the positional input instead.
        if not args and self.default_text is not None:
            args = (self.default_text,)
        return super().__call__(*args, **kwargs)
19
+
20
def generate_text(input_text, default_text="Enter your input text here"):
    """Run the module-level model on *input_text* and return its output.

    Parameters
    ----------
    input_text : str
        Text fed to the model.
    default_text : str
        Fallback input the custom pipeline uses when called with no text.

    Returns
    -------
    list
        Pipeline output (a list of dicts with a "generated_text" key).
    """
    # BUG FIX: the original passed the *pipeline* object as ``model=``.
    # Text2TextGenerationPipeline expects the underlying model, which the
    # pipeline wrapper exposes as ``model.model`` (and its tokenizer as
    # ``model.tokenizer``).
    generator = MyText2TextGenerationPipeline(
        model=model.model,
        tokenizer=model.tokenizer,
        default_text=default_text,
    )
    return generator(input_text, max_length=100)
27
+
28
# Example usage -- guarded so that importing this module (e.g. from an
# inference endpoint worker) does not trigger a generation as a side effect.
if __name__ == "__main__":
    input_text = "The quick brown fox jumps over the lazy dog."
    generated_text = generate_text(input_text, default_text="Enter a new input text here")
    print(generated_text)