Chris4K committed on
Commit
85f410a
·
verified ·
1 Parent(s): 33239be

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +8 -1
services/strategy.py CHANGED
@@ -54,6 +54,7 @@ class MajorityVotingStrategy(GenerationStrategy):
54
 
55
 
56
  class BestOfN(GenerationStrategy):
 
57
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
58
  scored_outputs = []
59
  for _ in range(num_samples):
@@ -65,9 +66,14 @@ class BestOfN(GenerationStrategy):
65
  response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
66
 
67
  # Tokenize the response for scoring with the PRM model
 
68
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
 
 
69
  prm_output = generator.prm_model(**response_inputs) # Pass the inputs correctly to the model
70
- score = prm_output.logits.mean().item()
 
 
71
 
72
  # Append the response and its score
73
  scored_outputs.append((response, score))
@@ -76,6 +82,7 @@ class BestOfN(GenerationStrategy):
76
  return max(scored_outputs, key=lambda x: x[1])[0]
77
 
78
 
 
79
  class BeamSearch(GenerationStrategy):
80
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
81
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
 
54
 
55
 
56
  class BestOfN(GenerationStrategy):
57
+ @observe()
58
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
59
  scored_outputs = []
60
  for _ in range(num_samples):
 
66
  response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
67
 
68
  # Tokenize the response for scoring with the PRM model
69
+ #TODO use the real tokenizer from generator
70
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
71
+
72
+ # Pass the response inputs correctly to the PRM model
73
  prm_output = generator.prm_model(**response_inputs) # Pass the inputs correctly to the model
74
+
75
+ # Check the expected output structure for prm_model and use it accordingly
76
+ score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
77
 
78
  # Append the response and its score
79
  scored_outputs.append((response, score))
 
82
  return max(scored_outputs, key=lambda x: x[1])[0]
83
 
84
 
85
+
86
  class BeamSearch(GenerationStrategy):
87
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
88
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)