alfiannajih commited on
Commit
06fa6b1
·
verified ·
1 Parent(s): db7044b

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +1 -1
g_retriever.py CHANGED
@@ -42,7 +42,7 @@ class GRetrieverModel(LlamaForCausalLM):
42
  # mean pooling
43
  g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
- return g_embeds.to(model.device)
46
 
47
  @wraps(LlamaForCausalLM.forward)
48
  def forward(
 
42
  # mean pooling
43
  g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
+ return g_embeds.to(self.model.device)
46
 
47
  @wraps(LlamaForCausalLM.forward)
48
  def forward(