bananabot committed
Commit 193ebc9
1 Parent(s): d11a12a

Update app.py

Files changed (1)
  app.py +27 -3
app.py CHANGED
@@ -1,13 +1,37 @@
  import torch
  import pandas as pd
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TrainingArguments
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TrainingArguments, Trainer
  import gradio as gr
  from gradio.mix import Parallel, Series
  #import torch.nn.functional as F
  from aitextgen import aitextgen
+
  from datasets import load_dataset
  dataset = load_dataset("bananabot/engMollywoodSummaries")
- dataset
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+
+ def tokenize_function(examples):
+     return tokenizer(examples["text"], padding="max_length", truncation=True)
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to(device)
+ training_args = TrainingArguments(output_dir="test_trainer")
+ small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+ small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     return metric.compute(predictions=predictions, references=labels)
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=small_train_dataset,
+     eval_dataset=small_eval_dataset,
+     compute_metrics=compute_metrics,
+ )
+
+ trainer.train()

  device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -15,7 +39,7 @@ ai = aitextgen(model="EleutherAI/gpt-neo-1.3B")

  #model_name = "EleutherAI/gpt-neo-125M"
  #tokenizer = AutoTokenizer.from_pretrained(model_name)
- #model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+

  #max_length=123
  #input_txt = "This malayalam movie is about"
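
Note on the added training block: as committed it is not self-contained. `compute_metrics` uses `np` and `metric` without importing or defining them, and `device` is referenced on the `model = ...` line before the later `device = "cuda" ...` assignment has run. Below is a minimal sketch of the missing pieces, under assumptions the commit does not state: that accuracy is the intended metric, and that the GPT-Neo tokenizer's pad token still needs to be set before padding.

  import numpy as np
  from datasets import load_metric  # newer stacks would use: import evaluate; evaluate.load(...)

  device = "cuda" if torch.cuda.is_available() else "cpu"  # must exist before .to(device) is called
  tokenizer.pad_token = tokenizer.eos_token                # GPT-Neo ships without a pad token; needed for padding="max_length"
  metric = load_metric("accuracy")                         # assumption: accuracy; the commit never names a metric

  # The tokenized dataset only carries input_ids and attention_mask, so the Trainer has no labels
  # to compute a causal-LM loss from; a data collator is the usual fix.
  from transformers import DataCollatorForLanguageModeling
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # copies input_ids into labels
  # would then be passed as Trainer(..., data_collator=data_collator)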