DEADLOCK007X committed on
Commit
87e26da
·
verified ·
1 Parent(s): ab7ddd9

Update tinyllama_inference.py

Browse files
Files changed (1) hide show
  1. tinyllama_inference.py +7 -8
tinyllama_inference.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
  def load_model():
@@ -20,7 +19,7 @@ Solution: "{code}"
20
  Return ONLY valid JSON: {{"stars": number, "feedback": string}}
21
  Do not include any extra text outside the JSON.
22
  """
23
- # Load the model and tokenizer.
24
  tokenizer, model = load_model()
25
  inputs = tokenizer(prompt, return_tensors="pt")
26
  outputs = model.generate(**inputs, max_new_tokens=150)
@@ -31,13 +30,13 @@ Do not include any extra text outside the JSON.
31
  result = {"stars": 0, "feedback": "Evaluation failed. Unable to parse AI response."}
32
  return result
33
 
34
- # For direct testing from the command line
35
  if __name__ == "__main__":
36
  import sys
37
  if len(sys.argv) < 3:
38
  print(json.dumps({"error": "Please provide a question and code as arguments"}))
39
- sys.exit(1)
40
- question = sys.argv[1]
41
- code = sys.argv[2]
42
- result = evaluate_code(question, code)
43
- print(json.dumps(result))
 
1
  import json
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  def load_model():
 
19
  Return ONLY valid JSON: {{"stars": number, "feedback": string}}
20
  Do not include any extra text outside the JSON.
21
  """
22
+ # Load model and tokenizer (for simplicity, we load them per request; consider caching for performance)
23
  tokenizer, model = load_model()
24
  inputs = tokenizer(prompt, return_tensors="pt")
25
  outputs = model.generate(**inputs, max_new_tokens=150)
 
30
  result = {"stars": 0, "feedback": "Evaluation failed. Unable to parse AI response."}
31
  return result
32
 
33
# For direct testing from the command line:
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 3:
        # Emit the error as JSON so callers that parse stdout still get
        # structured output, but exit nonzero so shell callers / CI can
        # detect the usage error (exit status 0 would signal success).
        print(json.dumps({"error": "Please provide a question and code as arguments"}))
        sys.exit(1)
    question = sys.argv[1]
    code = sys.argv[2]
    # evaluate_code returns a dict like {"stars": number, "feedback": string}.
    result = evaluate_code(question, code)
    print(json.dumps(result))