import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, logging


checkpoint = "Salesforce/codet5p-770m"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Silence verbose transformers warnings once at import time.
logging.set_verbosity(logging.CRITICAL)

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, cache_dir="models/")

# Build the text2text-generation pipeline once, reusing the model and tokenizer
# loaded above instead of re-loading the checkpoint on every request.
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=124,
    do_sample=True,  # required for temperature / top_p to take effect
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    device=device,
)


def code_gen(text):
    """Generate code for the given prompt and return the decoded text."""
    response = pipe(text)
    print(response)
    return response[0]["generated_text"]


# gr.inputs.Textbox was removed in newer Gradio releases; use gr.Textbox directly.
iface = gr.Interface(
    fn=code_gen,
    inputs=gr.Textbox(label="Input Source Code"),
    outputs="text",
    title="Code Generation",
)

iface.launch()
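
# Quick local sanity check (hypothetical prompt; any short code stub works):
#     code_gen("def print_hello_world():")
# is expected to return a generated completion string from CodeT5+.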