streetyogi commited on
Commit
7d22c1d
·
1 Parent(s): 3ecf051

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -51
main.py CHANGED
@@ -1,56 +1,29 @@
1
- from fastapi import FastAPI
2
- from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import FileResponse
4
- from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments
5
 
6
  app = FastAPI()
7
 
8
- # Initialize the tokenizer and model
9
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
10
- model = BertForMaskedLM.from_pretrained("bert-base-uncased")
11
 
12
- # Prepare the training data
13
  with open("cyberpunk_lore.txt", "r") as f:
14
- train_data = f.read()
15
- train_data = train_data.split("\n")
16
- train_data = [tokenizer.encode(text, return_tensors="pt") for text in train_data]
17
-
18
- # Define the training arguments
19
- training_args = TrainingArguments(
20
- output_dir="./results",
21
- per_device_train_batch_size=16,
22
- save_steps=10_000,
23
- save_total_limit=2,
24
- )
25
-
26
- # Create the trainer
27
- trainer = Trainer(
28
- model=model,
29
- args=training_args,
30
- train_dataset=train_data,
31
- eval_dataset=train_data,
32
- )
33
-
34
- # Start the training
35
- trainer.train()
36
-
37
- # Save the fine-tuned model
38
- trainer.save_model('./results')
39
-
40
- # Load the fine-tuned model
41
- model = trainer.get_model()
42
-
43
- # Create the inference endpoint
44
- @app.post("/infer")
45
- def infer(input: str):
46
- input_ids = tokenizer.encode(input, return_tensors="pt")
47
- output = model(input_ids)[0]
48
- return {"output": output}
49
-
50
- @app.get("/")
51
- def index() -> FileResponse:
52
- return FileResponse(path="/app/static/index.html", media_type="text/html")
53
-
54
- @app.get("/")
55
- def index() -> FileResponse:
56
- return FileResponse(path="/app/static/index.html", media_type="text/html")
 
1
+ import torch
2
+ from transformers import RobertaForMaskedLM, RobertaTokenizer
3
+ from fastapi import FastAPI, HTTPException
 
4
 
5
  app = FastAPI()
6
 
7
+ # Load the pre-trained model and tokenizer
8
+ model = RobertaForMaskedLM.from_pretrained('roberta-base')
9
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
10
 
11
+ # Load your dataset, in this case "cyberpunk_lore.txt"
12
  with open("cyberpunk_lore.txt", "r") as f:
13
+ dataset = f.read()
14
+
15
+ # Train the model on your dataset
16
+ input_ids = torch.tensor([tokenizer.encode(dataset, add_special_tokens=True)])
17
+ model.train()
18
+ model.zero_grad()
19
+ outputs = model(input_ids, labels=input_ids)
20
+ loss, logits = outputs[:2]
21
+ loss.backward()
22
+
23
+ # Serve the model via FastAPI
24
+ @app.post("/predict")
25
+ def predict(prompt: str):
26
+ input_ids = torch.tensor([tokenizer.encode(prompt, add_special_tokens=True)])
27
+ outputs = model(input_ids)
28
+ generated_text = tokenizer.decode(outputs[0].argmax(dim=1).tolist()[0])
29
+ return {"generated_text": generated_text}