samkeet commited on
Commit
d4050a7
·
verified ·
1 Parent(s): 13e93c3

Upload GPT124MTextGenerationPipeline

Browse files
Files changed (2) hide show
  1. config.json +15 -0
  2. pipeline_gpt.py +0 -4
config.json CHANGED
@@ -7,6 +7,21 @@
7
  "AutoModelForCausalLM": "modeling_gpt.GPTModelForTextGeneration"
8
  },
9
  "block_size": 1024,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  "model_type": "custom_gpt",
11
  "n_embd": 768,
12
  "n_head": 12,
 
7
  "AutoModelForCausalLM": "modeling_gpt.GPTModelForTextGeneration"
8
  },
9
  "block_size": 1024,
10
+ "custom_pipelines": {
11
+ "text-generation": {
12
+ "default": {
13
+ "model": {
14
+ "pt": "samkeet/GPT_124M"
15
+ }
16
+ },
17
+ "impl": "pipeline_gpt.GPT124MTextGenerationPipeline",
18
+ "pt": [
19
+ "AutoModelForCausalLM"
20
+ ],
21
+ "tf": [],
22
+ "type": "text"
23
+ }
24
+ },
25
  "model_type": "custom_gpt",
26
  "n_embd": 768,
27
  "n_head": 12,
pipeline_gpt.py CHANGED
@@ -49,10 +49,6 @@ class GPT124MTextGenerationPipeline(Pipeline):
49
  Forwards the tokenized input to the model's generate method.
50
  """
51
 
52
- # Access the actual GPT model inside GPTModelForTextGeneration
53
- if isinstance(self.model, GPTModelForTextGeneration):
54
- self.model = self.model.model
55
-
56
  return self.model.generate(**model_inputs, **forward_kwargs)
57
 
58
  def postprocess(self, model_output, **postprocess_kwargs):
 
49
  Forwards the tokenized input to the model's generate method.
50
  """
51
 
 
 
 
 
52
  return self.model.generate(**model_inputs, **forward_kwargs)
53
 
54
  def postprocess(self, model_output, **postprocess_kwargs):