File size: 3,963 Bytes
651bb25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import re
import os

def load_model(model_dir='./models'):
    """Load the Pegasus tokenizer and model from a local directory.

    Args:
        model_dir: Directory holding the saved tokenizer and model files.
            Defaults to ``'./models'``, the path that was previously
            hard-coded here.

    Returns:
        Tuple ``(tokenizer, model, torch_device)``. ``torch_device`` is
        ``'cuda'`` when a GPU is available and ``'cpu'`` otherwise, and the
        model has already been moved to that device.
    """
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {torch_device}")

    # Both tokenizer and model are read from the same local directory.
    tokenizer = PegasusTokenizer.from_pretrained(model_dir)
    model = PegasusForConditionalGeneration.from_pretrained(model_dir).to(torch_device)
    return tokenizer, model, torch_device

def split_into_paragraphs(text):
    """Split text into non-empty paragraphs.

    Paragraphs are separated by one or more blank lines. A "blank" line may
    contain only whitespace, and both LF and CRLF line endings are handled.
    Each paragraph is stripped of surrounding whitespace, and empty
    paragraphs are dropped, so runs of blank lines are not preserved.

    Args:
        text: Input text. ``None`` or empty text yields ``[]``.

    Returns:
        List of stripped, non-empty paragraph strings.
    """
    if not text:
        return []
    # A plain split('\n\n') misses separators like "\n   \n" or "\r\n\r\n",
    # silently merging distinct paragraphs; \n\s*\n catches all of them.
    paragraphs = re.split(r'\n\s*\n', text)
    return [p.strip() for p in paragraphs if p.strip()]

def split_into_sentences(paragraph):
    """Break a paragraph into sentences.

    A sentence boundary is any run of whitespace that directly follows a
    '.', '!' or '?'. The punctuation stays attached to its sentence. Each
    piece is stripped, and empty pieces are discarded.
    """
    pieces = re.split(r'(?<=[.!?])\s+', paragraph)
    trimmed = (piece.strip() for piece in pieces)
    return [piece for piece in trimmed if piece]

def get_response(input_text, num_return_sequences, tokenizer, model, torch_device,
                 num_beams=10, max_length=80):
    """Paraphrase a single piece of text with beam search.

    Args:
        input_text: Text to paraphrase. It is truncated to ``max_length``
            tokens before generation.
        num_return_sequences: Number of beam hypotheses to generate. Only the
            first decoded hypothesis is returned.
        tokenizer: Tokenizer used to encode the input and decode the output.
        model: Seq2seq model whose ``generate`` produces the paraphrase.
        torch_device: Device the encoded batch is moved to.
        num_beams: Beam width (default 10). It is raised to
            ``num_return_sequences`` when smaller, because beam search cannot
            return more sequences than it has beams.
        max_length: Maximum token length for both the input and the generated
            output (default 80).

    Returns:
        The first decoded paraphrase, as a string.
    """
    # ``prepare_seq2seq_batch`` is deprecated in transformers; calling the
    # tokenizer directly yields the same input_ids/attention_mask for a
    # source-only batch.
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=max_length,
        return_tensors="pt"
    ).to(torch_device)

    generated = model.generate(
        **batch,
        num_beams=max(num_beams, num_return_sequences),
        num_return_sequences=num_return_sequences,
        temperature=1.0,
        repetition_penalty=2.8,
        length_penalty=1.2,
        max_length=max_length,
        min_length=5,
        no_repeat_ngram_size=3
    )

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]

def get_response_from_text(context, tokenizer, model, torch_device):
    """Process entire text while preserving paragraph structure."""
    paragraphs = split_into_paragraphs(context)
    paraphrased_paragraphs = []

    for paragraph in paragraphs:
        sentences = split_into_sentences(paragraph)
        paraphrased_sentences = []

        for sentence in sentences:
            if len(sentence.split()) < 3:
                paraphrased_sentences.append(sentence)
                continue

            try:
                paraphrased = get_response(sentence, 1, tokenizer, model, torch_device)
                if not any(phrase in paraphrased.lower() for phrase in ['it\'s like', 'in other words']):
                    paraphrased_sentences.append(paraphrased)
                else:
                    paraphrased_sentences.append(sentence)
            except Exception as e:
                print(f"Error processing sentence: {e}")
                paraphrased_sentences.append(sentence)

        paraphrased_paragraphs.append(' '.join(paraphrased_sentences))

    return '\n\n'.join(paraphrased_paragraphs)

def create_interface():
    """Build the Gradio paraphraser UI.

    Loads the tokenizer and model once. It then wraps a closure that
    paraphrases the submitted text in a ``gr.Interface`` with one input
    textbox and one output textbox.

    Returns:
        The configured, not yet launched, ``gr.Interface``.
    """
    tokenizer, model, torch_device = load_model()

    def greet(context):
        # The closure reuses the already-loaded model for every request.
        return get_response_from_text(context, tokenizer, model, torch_device)

    source_box = gr.Textbox(
        lines=15,
        label="Input Text",
        placeholder="Enter your text here...",
        elem_classes="input-text"
    )
    result_box = gr.Textbox(
        lines=15,
        label="Paraphrased Text",
        elem_classes="output-text"
    )

    return gr.Interface(
        fn=greet,
        inputs=source_box,
        outputs=result_box,
        title="Advanced Text Paraphraser",
        description="Enter text to generate a high-quality paraphrased version while maintaining paragraph structure.",
        theme="default",
        css="""
            .input-text, .output-text {
                font-size: 16px !important;
                font-family: Arial, sans-serif !important;
                min-height: 300px !important;
            }
        """
    )

if __name__ == "__main__":
    # Script entry point: build the UI (loading the model once) and serve it.
    create_interface().launch()