Nathan Slaughter committed on
Commit
8428312
1 Parent(s): c7db8fe
.gitignore CHANGED
@@ -1 +1,3 @@
1
  __pycache__
 
 
 
1
  __pycache__
2
+ .coverage
3
+ *.log
app/interface.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from .pipeline import Pipeline
3
  from .processing import process_file, process_text_input
4
 
 
1
  import gradio as gr
2
+
3
  from .pipeline import Pipeline
4
  from .processing import process_file, process_text_input
5
 
app/models.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import csv
3
+ from io import StringIO
4
+
5
+ from pydantic import BaseModel, validator, ValidationError
6
+
7
class Card(BaseModel):
    """A single flashcard: one question paired with its answer."""
    question: str
    answer: str
10
+
11
class Message(BaseModel):
    """A chat message whose content is a list of flashcards.

    ``content`` may be given either as a ready-made list of ``Card``
    objects or as a JSON string encoding such a list; the pre-validator
    decodes the string form before field validation runs.
    """
    role: str
    content: list[Card]

    @validator('content', pre=True)
    def parse_content(cls, v):
        # Non-string values are handed to pydantic untouched.
        if not isinstance(v, str):
            return v
        try:
            return json.loads(v)
        except json.JSONDecodeError as e:
            raise ValueError(f"Error decoding 'content' JSON: {e}") from e

    def content_to_json(self) -> str:
        """Render the cards as a pretty-printed JSON array."""
        card_dicts = [card.dict() for card in self.content]
        return json.dumps(card_dicts, indent=2)

    def content_to_csv(self) -> str:
        """
        Converts the content of the Message instance into a CSV string.
        """
        buffer = StringIO()
        # csv defaults to Windows-style "\r\n" terminators; force Unix-style.
        csv_writer = csv.writer(buffer, lineterminator='\n')
        csv_writer.writerow(["Question", "Answer"])
        csv_writer.writerows((card.question, card.answer) for card in self.content)
        rendered = buffer.getvalue()
        buffer.close()
        return rendered
41
+
42
class PydanticEncoder(json.JSONEncoder):
    """JSON encoder that serializes pydantic models via their ``.dict()`` form."""

    def default(self, obj):
        # Fall through to the stock encoder for anything that is not a model.
        if not isinstance(obj, BaseModel):
            return super().default(obj)
        return obj.dict()
app/pipeline.py CHANGED
@@ -1,48 +1,14 @@
1
  from io import StringIO
2
- import csv
3
  import json
4
  import logging
5
 
6
  import torch
7
  from transformers import pipeline
8
- from pydantic import BaseModel, ValidationError, validator
9
 
10
- logger = logging.getLogger(__name__)
11
-
12
- class Card(BaseModel):
13
- question: str
14
- answer: str
15
-
16
- class Message(BaseModel):
17
- role: str
18
- content: list[Card]
19
-
20
- @validator('content', pre=True)
21
- def parse_content(cls, v):
22
- if isinstance(v, str):
23
- try:
24
- content_list = json.loads(v)
25
- return content_list
26
- except json.JSONDecodeError as e:
27
- raise ValueError(f"Error decoding 'content' JSON: {e}") from e
28
- return v
29
-
30
- def content_to_json(self) -> str:
31
- return json.dumps([card.dict() for card in self.content], indent=2)
32
-
33
- def content_to_csv(self) -> str:
34
- output = StringIO()
35
- writer = csv.writer(output)
36
- writer.writerow(['Question', 'Answer']) # CSV Header
37
- for card in self.content:
38
- writer.writerow([card.question, card.answer])
39
- return output.getvalue()
40
 
41
- class PydanticEncoder(json.JSONEncoder):
42
- def default(self, obj):
43
- if isinstance(obj, BaseModel):
44
- return obj.dict()
45
- return super().default(obj)
46
 
47
  class Pipeline:
48
  def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
@@ -53,6 +19,7 @@ class Pipeline:
53
  device_map="auto"
54
  )
55
  self.device = self._determine_device()
 
56
  self.messages = [
57
  {"role": "system", "content": """You are an expert flashcard creator. You always include a single knowledge item per flashcard.
58
  - You ALWAYS include a single knowledge item per flashcard.
@@ -77,46 +44,9 @@ class Pipeline:
77
  )[0]["generated_text"][-1]
78
  return response_message
79
 
80
- def format_flashcards(self, output_format: str, response: str) -> str:
81
- output = ""
82
- try :
83
- message = parse_message(response)
84
- logger.debug("after parse_obj_as")
85
- except ValidationError as e:
86
- raise e
87
- if output_format.lower() == "json":
88
- output = message.content_to_json()
89
- elif output_format.lower() == "csv":
90
- output = message.content_to_csv()
91
- return output
92
-
93
  def generate_flashcards(self, output_format: str, content: str) -> str:
94
  response = self.extract_flashcards(content)
95
- return self.format_flashcards(output_format, response)
96
-
97
- def parse_message(self, input_dict: dict[str, any]) -> Message:
98
- try:
99
- # Extract the role
100
- role = input_dict['role']
101
-
102
- # Parse the content
103
- content = input_dict['content']
104
- # If content is a string, try to parse it as JSON
105
- if isinstance(content, str):
106
- content = content.strip()
107
- content = json.loads(content)
108
-
109
- # Create Card objects from the content
110
- cards = [Card(**item) for item in content]
111
-
112
- # Create and return the Message object
113
- return Message(role=role, content=cards)
114
- except json.JSONDecodeError as e:
115
- raise ValueError(f"Invalid JSON in content: {str(e)}")
116
- except ValidationError as e:
117
- raise ValueError(f"Validation error: {str(e)}")
118
- except KeyError as e:
119
- raise ValueError(f"Missing required key: {str(e)}")
120
 
121
  def _determine_device(self):
122
  if torch.cuda.is_available():
@@ -144,8 +74,11 @@ def parse_message(input_dict: dict[str, any]) -> Message:
144
  # Create and return the Message object
145
  return Message(role=role, content=cards)
146
  except json.JSONDecodeError as e:
 
147
  raise ValueError(f"Invalid JSON in content: {str(e)}")
148
  except ValidationError as e:
 
149
  raise ValueError(f"Validation error: {str(e)}")
150
  except KeyError as e:
 
151
  raise ValueError(f"Missing required key: {str(e)}")
 
1
  from io import StringIO
 
2
  import json
3
  import logging
4
 
5
  import torch
6
  from transformers import pipeline
 
7
 
8
+ from .models import Card, Message, ValidationError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(filename="pipeline.log", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")
 
 
 
12
 
13
  class Pipeline:
14
  def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
 
19
  device_map="auto"
20
  )
21
  self.device = self._determine_device()
22
+ logger.info(f"device type: {self.device}")
23
  self.messages = [
24
  {"role": "system", "content": """You are an expert flashcard creator. You always include a single knowledge item per flashcard.
25
  - You ALWAYS include a single knowledge item per flashcard.
 
44
  )[0]["generated_text"][-1]
45
  return response_message
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def generate_flashcards(self, output_format: str, content: str) -> str:
48
  response = self.extract_flashcards(content)
49
+ return format_flashcards(output_format, response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def _determine_device(self):
52
  if torch.cuda.is_available():
 
74
  # Create and return the Message object
75
  return Message(role=role, content=cards)
76
  except json.JSONDecodeError as e:
77
+ logger.error(f"Invalid JSON in content: {str(e)}")
78
  raise ValueError(f"Invalid JSON in content: {str(e)}")
79
  except ValidationError as e:
80
+ logger.error(f"Validation error: {str(e)}")
81
  raise ValueError(f"Validation error: {str(e)}")
82
  except KeyError as e:
83
+ logger.error(f"Missing required key: {str(e)}")
84
  raise ValueError(f"Missing required key: {str(e)}")
app/processing.py CHANGED
@@ -22,47 +22,6 @@ def read_text_file(file_path: str) -> str:
22
  except Exception as e:
23
  raise ValueError(f"Error reading text file: {str(e)}")
24
 
25
- def format_prompt(output_format: str) -> str:
26
- """
27
- Formats the prompt based on the output type.
28
- """
29
- if output_format.lower() == "json":
30
- return """You only respond in JSON format. Follow the example below.
31
-
32
- EXAMPLE:
33
- [
34
- {"question": "What is AI?", "answer": "Artificial Intelligence."},
35
- {"question": "What is ML?", "answer": "Machine Learning."}
36
- ]
37
- """
38
- elif output_format.lower() == "csv":
39
- return """You only respond with cards in CSV format. Follow the example below.
40
-
41
- EXAMPLE:
42
- "What is AI?", "Artificial Intelligence."
43
- "What is ML?", "Machine Learning."
44
- """
45
-
46
- # def extract_flashcards(text: str, output_format: str, pipeline: str) -> str:
47
- # """
48
- # Extracts flashcards from the input text using the LLM and formats them in CSV or JSON.
49
- # """
50
- # prompt = f"""You are an expert flashcard creator. You always include a single knowledge item per flashcard.
51
-
52
- # {format_prompt(output_format)}
53
-
54
-
55
- # Extract flashcards from the user's text:
56
-
57
- # {text}
58
-
59
- # Do not include the prompt or any other unnecessary information in the flashcards.
60
- # Do not include triple ticks (```) or any other code blocks in the flashcards.
61
- # """
62
- # # TODO:
63
- # response = pipeline.generate_flashcards("json", prompt)
64
- # return response
65
-
66
  def process_file(file_obj, output_format: str, pipeline) -> str:
67
  """
68
  Processes the uploaded file based on its type and extracts flashcards.
@@ -89,3 +48,17 @@ def process_text_input(output_format: str, input_text: str) -> str:
89
 
90
  flashcards = pipeline.generate_flashcards(output_format, input_text)
91
  return flashcards
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  except Exception as e:
23
  raise ValueError(f"Error reading text file: {str(e)}")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def process_file(file_obj, output_format: str, pipeline) -> str:
26
  """
27
  Processes the uploaded file based on its type and extracts flashcards.
 
48
 
49
  flashcards = pipeline.generate_flashcards(output_format, input_text)
50
  return flashcards
51
+
52
+
53
def format_flashcards(output_format: str, response: dict) -> str:
    """Validate a model response and render its flashcards as JSON or CSV.

    Args:
        output_format: "json" or "csv" (case-insensitive); any other value
            yields an empty string.
        response: Raw message dict produced by the pipeline
            (keys ``role`` and ``content``).

    Returns:
        The formatted flashcard string.

    Raises:
        ValueError: If the response content is invalid JSON, fails model
            validation, or is missing a required key (raised by
            ``parse_message``).
    """
    # The previous version was a module-level function that still carried a
    # stray `self` parameter, so the two-argument call site in
    # Pipeline.generate_flashcards raised TypeError.  It also referenced
    # `parse_message` and `logger`, neither of which exists in this module.
    # Import locally to avoid a circular import with app.pipeline.
    from .pipeline import parse_message

    message = parse_message(response)
    if output_format.lower() == "json":
        return message.content_to_json()
    if output_format.lower() == "csv":
        return message.content_to_csv()
    return ""
tests/conftest.py CHANGED
@@ -1,6 +1,6 @@
1
  import pytest
2
  from unittest.mock import Mock
3
- from app.pipeline import LanguageModel
4
 
5
  @pytest.fixture
6
  def pipeline():
@@ -8,7 +8,7 @@ def pipeline():
8
  Fixture to provide a mocked LanguageModel instance.
9
  """
10
  # Create a mock instance of LanguageModel
11
- lm = Mock(spec=LanguageModel)
12
  # Mock the generate_flashcards method
13
  lm.generate_flashcards.return_value = '{"flashcards": []}'
14
  return lm
 
1
  import pytest
2
  from unittest.mock import Mock
3
+ from app.pipeline import Pipeline
4
 
5
@pytest.fixture
def pipeline():
    """
    Fixture to provide a mocked Pipeline instance.
    """
    # Create a mock instance of Pipeline (spec restricts the mock to the
    # real class's attribute surface)
    lm = Mock(spec=Pipeline)
    # Mock the generate_flashcards method to return an empty flashcard payload
    lm.generate_flashcards.return_value = '{"flashcards": []}'
    return lm
tests/test_models.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from app.models import Card, Message
4
+
5
+ # Tests for Card and Message models
6
+ def test_card_model():
7
+ card = Card(question="What is Python?", answer="A programming language")
8
+ assert card.question == "What is Python?"
9
+ assert card.answer == "A programming language"
10
+
11
def test_message_model():
    """Message keeps its role and an ordered list of cards."""
    deck = [
        Card(question="What is AI?", answer="Artificial Intelligence"),
        Card(question="What is ML?", answer="Machine Learning"),
    ]
    msg = Message(role="assistant", content=deck)
    assert msg.role == "assistant"
    assert len(msg.content) == 2
    assert msg.content[0].question == "What is AI?"
20
+
21
def test_message_content_json_parsing():
    """A JSON-encoded content string is decoded into Card objects."""
    raw = '[{"question": "Q1", "answer": "A1"}, {"question": "Q2", "answer": "A2"}]'
    msg = Message(role="assistant", content=raw)
    assert len(msg.content) == 2
    assert msg.content[0].question == "Q1"
26
+
27
def test_message_content_to_json():
    """content_to_json emits the cards as a JSON array of dicts."""
    msg = Message(
        role="assistant",
        content=[Card(question="Q1", answer="A1"), Card(question="Q2", answer="A2")],
    )
    expected = [
        {"question": "Q1", "answer": "A1"},
        {"question": "Q2", "answer": "A2"},
    ]
    assert json.loads(msg.content_to_json()) == expected
35
+
36
def test_message_content_to_csv():
    """content_to_csv emits a header row plus one Unix-terminated row per card.

    The stale "# failed test" marker and the leftover debug print() are
    removed; the assertion itself is unchanged.
    """
    cards = [Card(question="Q1", answer="A1"), Card(question="Q2", answer="A2")]
    message = Message(role="assistant", content=cards)
    csv_output = message.content_to_csv()
    expected_output = "Question,Answer\nQ1,A1\nQ2,A2\n"  # Unix-style line endings
    assert csv_output == expected_output
tests/test_pipeline.py CHANGED
@@ -1,18 +1,67 @@
1
  import pytest
 
 
 
 
 
 
2
 
3
- def test_generate_flashcards(pipeline, mocker):
4
- """
5
- Test the generate_flashcards method of LanguageModel.
6
- """
7
- prompt = "Sample prompt for flashcard generation."
8
- expected_response = '{"flashcards": [{"Question": "What is AI?", "Answer": "Artificial Intelligence."}]}'
9
 
10
- # Configure the mock to return a specific response
11
- pipeline.generate_flashcards.return_value = expected_response
 
 
 
12
 
13
- # Call the method
14
- response = pipeline.generate_flashcards(prompt)
 
 
15
 
16
- # Assertions
17
- assert response == expected_response
18
- pipeline.generate_flashcards.assert_called_once_with(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pytest
2
+ from unittest.mock import Mock, patch
3
+ import json
4
+ from io import StringIO
5
+ from pydantic import ValidationError
6
+ from app.pipeline import Pipeline, Message, Card, parse_message
7
+ from app.models import PydanticEncoder
8
 
9
# Tests for Pipeline class
@pytest.fixture
def mock_pipeline():
    """Yield a Pipeline built against a mocked transformers factory.

    Patches the `pipeline` name inside app.pipeline (the transformers
    factory called by Pipeline.__init__) so constructing the fixture does
    not download or load a real model.  The previous target,
    patch('app.pipeline'), replaced the module attribute on the package
    but still let Pipeline("mock_model") run the real model load.
    """
    with patch('app.pipeline.pipeline') as mock_pipe:
        mock_pipe.return_value = Mock()
        yield Pipeline("mock_model")
15
 
16
+ # def test_extract_flashcards(mock_pipeline):
17
+ # mock_pipeline.torch_pipe.return_value = [{"generated_text": [{"role": "assistant", "content": '[{"question": "Q", "answer": "A"}]'}]}]
18
+ # response = mock_pipeline.extract_flashcards("Test content")
19
+ # assert isinstance(response, dict)
20
+ # assert "content" in response
21
 
22
+ # def test_format_flashcards_csv(mock_pipeline):
23
+ # response = {"role": "assistant", "content": '[{"question": "Q", "answer": "A"}]'}
24
+ # formatted = mock_pipeline.format_flashcards("csv", response)
25
+ # assert formatted.strip() == "Question,Answer\nQ,A"
26
 
27
+ # def test_generate_flashcards(mock_pipeline):
28
+ # mock_pipeline.extract_flashcards.return_value = {"role": "assistant", "content": '[{"question": "Q", "answer": "A"}]'}
29
+ # result = mock_pipeline.generate_flashcards("json", "Test content")
30
+ # assert json.loads(result) == [{"question": "Q", "answer": "A"}]
31
+
32
# Tests for parse_message function
def test_parse_message_valid_input():
    """A well-formed role/content dict parses into a Message."""
    payload = {
        "role": "assistant",
        "content": '[{"question": "Q1", "answer": "A1"}, {"question": "Q2", "answer": "A2"}]'
    }
    result = parse_message(payload)
    assert isinstance(result, Message)
    assert result.role == "assistant"
    assert len(result.content) == 2
42
+
43
def test_parse_message_invalid_json():
    """Content that is not valid JSON raises ValueError."""
    payload = {"role": "assistant", "content": 'Invalid JSON'}
    with pytest.raises(ValueError, match="Invalid JSON in content"):
        parse_message(payload)
50
+
51
def test_parse_message_missing_key():
    """A dict lacking the 'role' key raises ValueError."""
    payload = {"content": '[{"question": "Q", "answer": "A"}]'}
    with pytest.raises(ValueError, match="Missing required key"):
        parse_message(payload)
57
+
58
# Test for PydanticEncoder
def test_pydantic_encoder():
    """PydanticEncoder serializes a BaseModel via its dict form."""
    serialized = json.dumps(Card(question="Q", answer="A"), cls=PydanticEncoder)
    assert json.loads(serialized) == {"question": "Q", "answer": "A"}
63
+
64
# Test error cases
def test_message_invalid_content():
    """A non-JSON content string fails Message validation."""
    with pytest.raises(ValidationError):
        Message(role="assistant", content="Invalid content")
tests/test_processing.py CHANGED
@@ -1,71 +1,40 @@
1
  import pytest
2
- from app.processing import process_text_input, process_file
3
-
4
- def test_process_text_input_success(pipeline):
5
- """
6
- Test processing of valid text input.
7
- """
8
- input_text = "This is a sample text for flashcard extraction."
9
- output_format = "JSON"
10
- expected_output = '{"flashcards": []}'
11
-
12
- result = process_text_input(input_text, output_format, pipeline)
13
- assert result == expected_output
14
- pipeline.generate_flashcards.assert_called_once()
15
-
16
- def test_process_text_input_empty(pipeline):
17
- """
18
- Test processing of empty text input.
19
- """
20
- input_text = " "
21
- output_format = "JSON"
22
-
23
- with pytest.raises(ValueError) as excinfo:
24
- process_text_input(input_text, output_format, pipeline)
25
- assert "No text provided." in str(excinfo.value)
26
-
27
- def test_process_file_unsupported_type(pipeline, tmp_path):
28
- """
29
- Test processing of an unsupported file type.
30
- """
31
- # Create a dummy unsupported file
32
- dummy_file = tmp_path / "dummy.unsupported"
33
- dummy_file.write_text("Unsupported content")
34
-
35
- with pytest.raises(ValueError) as excinfo:
36
- process_file(dummy_file, "JSON", pipeline)
37
- assert "Unsupported file type." in str(excinfo.value)
38
-
39
- def test_process_file_pdf(pipeline, tmp_path, mocker):
40
- """
41
- Test processing of a PDF file.
42
- """
43
- # Mock the process_pdf function
44
- mocker.patch('app.processing.process_pdf', return_value="Extracted PDF text.")
45
-
46
- # Create a dummy PDF file
47
- dummy_file = tmp_path / "test.pdf"
48
- dummy_file.write_text("PDF content")
49
-
50
- expected_output = '{"flashcards": []}'
51
-
52
- result = process_file(dummy_file, "JSON", pipeline)
53
- assert result == expected_output
54
- pipeline.generate_flashcards.assert_called_once()
55
-
56
- def test_process_file_txt(pipeline, tmp_path, mocker):
57
- """
58
- Test processing of a TXT file.
59
- """
60
- # Mock the read_text_file function
61
- mocker.patch('app.processing.read_text_file', return_value="Extracted TXT text.")
62
-
63
- # Create a dummy TXT file
64
- dummy_file = tmp_path / "test.txt"
65
- dummy_file.write_text("TXT content")
66
-
67
- expected_output = '{"flashcards": []}'
68
-
69
- result = process_file(dummy_file, "JSON", pipeline)
70
- assert result == expected_output
71
- pipeline.generate_flashcards.assert_called_once()
 
1
  import pytest
2
+ from unittest.mock import patch, Mock
3
+ from app.processing import process_pdf, read_text_file, process_file, process_text_input
4
+
5
def test_read_text_file_error():
    """An IOError while opening the file surfaces as a ValueError."""
    with patch("builtins.open", side_effect=IOError("File read error")):
        with pytest.raises(ValueError, match="Error reading text file: File read error"):
            read_text_file("test.txt")
9
+
10
# Test for process_file function
def test_process_file_pdf(pipeline):
    """A .pdf upload is routed through process_pdf before flashcard generation."""
    fake_upload = Mock()
    fake_upload.name = "test.pdf"

    with patch('app.processing.process_pdf', return_value="PDF content"):
        output = process_file(fake_upload, "json", pipeline)
        pipeline.generate_flashcards.assert_called_once_with("json", "PDF content")
        assert output == '{"flashcards": []}'
19
+
20
def test_process_file_txt(pipeline):
    """A .txt upload is read via read_text_file before flashcard generation."""
    fake_upload = Mock()
    fake_upload.name = "test.txt"

    with patch('app.processing.read_text_file', return_value="Text content"):
        output = process_file(fake_upload, "json", pipeline)
        pipeline.generate_flashcards.assert_called_once_with("json", "Text content")
        assert output == '{"flashcards": []}'
28
+
29
def test_process_file_unsupported():
    """An unrecognized extension raises ValueError before the pipeline is used."""
    fake_upload = Mock()
    fake_upload.name = "test.unsupported"

    with pytest.raises(ValueError, match="Unsupported file type."):
        process_file(fake_upload, "json", None)
35
+
36
# Ensure the pipeline fixture is used in all tests that require it
@pytest.mark.usefixtures("pipeline")
class TestWithPipeline:
    def test_pipeline_usage(self, pipeline):
        """The shared fixture is preconfigured with the canned flashcard payload."""
        canned = pipeline.generate_flashcards.return_value
        assert canned == '{"flashcards": []}'