bananabot's picture
Update app.py
e20cae0
raw
history blame
1.23 kB
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TrainingArguments
import gradio as gr
from gradio.mix import Parallel, Series
#import torch.nn.functional as F
from datasets import load_dataset
# Load the English Mollywood summaries dataset from the Hugging Face Hub.
# NOTE(review): `dataset` is not referenced anywhere else in this file; it
# looks like a leftover from experimentation — confirm before relying on it.
# (The bare `dataset` expression that followed was a notebook-style display
# and had no effect in a script, so it was removed.)
dataset = load_dataset("bananabot/engMollywoodSummaries")
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Local copy of GPT-Neo 125M. In this file it is only used to print one
# sample continuation at startup; the Gradio demo below loads the model
# separately through gr.Interface.load.
model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Cap on the TOTAL sequence length (prompt tokens + generated tokens).
max_length = 123
input_txt = "This malayalam movie is about"

# Keep the whole encoding so the attention mask can be passed to generate()
# alongside the token ids.
encoded = tokenizer(input_txt, return_tensors="pt").to(device)
input_ids = encoded["input_ids"]

# Beam-search multinomial sampling (num_beams > 1 together with
# do_sample=True), filtered by top-k / top-p, with repeated bigrams blocked.
# The GPT-Neo tokenizer defines no pad token, so EOS is used explicitly as
# the pad id; generate() would fall back to the same value, but with a
# warning on every call.
output = model.generate(
    input_ids,
    attention_mask=encoded["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,
    max_length=max_length,
    num_beams=5,
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=1.37,
    top_k=69,
    top_p=0.96,
)
# Drop special tokens such as <|endoftext|> from the printed sample.
print(tokenizer.decode(output[0], skip_special_tokens=True))
# Two models loaded from the Hugging Face Hub via gr.Interface.load: an
# English text generator and an English -> Malayalam translator.
generator = gr.Interface.load("models/EleutherAI/gpt-neo-125M")
translator = gr.Interface.load("models/Helsinki-NLP/opus-mt-en-ml")

# Chain them with gradio's Series (imported from gradio.mix above): the
# demo generates text from the prompt, translates it to Malayalam, and
# shows the final result. gr.Textbox replaces the deprecated
# gr.inputs.Textbox.
demo = Series(
    generator,
    translator,
    inputs=gr.Textbox(lines=13, label="Input Text"),
)
demo.launch()