gabykim commited on
Commit
d9d9220
·
1 Parent(s): 4e58eba

docstring before function definition

Browse files
src/know_lang_bot/code_parser/parser.py CHANGED
@@ -14,7 +14,6 @@ LOG = FancyLogger(__name__)
14
  class ChunkType(str, Enum):
15
  CLASS = "class"
16
  FUNCTION = "function"
17
- MODULE = "module"
18
  OTHER = "other"
19
 
20
  class CodeChunk(BaseModel):
@@ -44,14 +43,36 @@ class CodeParser:
44
  self.language = Language(tree_sitter_python.language())
45
  self.parser = Parser(self.language)
46
 
47
- def _extract_docstring(self, node: Node, source_code: bytes) -> Optional[str]:
48
- """Extract docstring from a class or function node"""
49
- for child in node.children:
50
- if child.type == "expression_statement":
51
- string_node = child.children[0]
52
- if string_node.type in ("string", "string_literal"):
53
- return source_code[string_node.start_byte:string_node.end_byte].decode('utf-8')
54
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def parse_file(self, file_path: Path) -> List[CodeChunk]:
57
  """Parse a single file and return list of code chunks"""
@@ -64,6 +85,12 @@ class CodeParser:
64
  source_code = f.read()
65
 
66
  tree = self.parser.parse(source_code)
 
 
 
 
 
 
67
  chunks: List[CodeChunk] = []
68
 
69
  # Process the syntax tree
@@ -73,15 +100,8 @@ class CodeParser:
73
  elif node.type == "function_definition":
74
  chunks.append(self._process_function(node, source_code, file_path))
75
  else:
76
- # Store other top-level code as separate chunks
77
- if node.type not in ("comment", "empty_statement"):
78
- chunks.append(CodeChunk(
79
- type=ChunkType.OTHER,
80
- content=source_code[node.start_byte:node.end_byte].decode('utf-8'),
81
- start_line=node.start_point[0],
82
- end_line=node.end_point[0],
83
- file_path=str(file_path)
84
- ))
85
 
86
  return chunks
87
  except Exception as e:
@@ -94,6 +114,9 @@ class CodeParser:
94
  for child in node.children
95
  if child.type == "identifier")
96
 
 
 
 
97
  return CodeChunk(
98
  type=ChunkType.CLASS,
99
  name=name,
@@ -101,7 +124,7 @@ class CodeParser:
101
  start_line=node.start_point[0],
102
  end_line=node.end_point[0],
103
  file_path=str(file_path),
104
- docstring=self._extract_docstring(node, source_code)
105
  )
106
 
107
  def _process_function(self, node: Node, source_code: bytes, file_path: Path) -> CodeChunk:
@@ -109,6 +132,9 @@ class CodeParser:
109
  name = next(child.text.decode('utf-8')
110
  for child in node.children
111
  if child.type == "identifier")
 
 
 
112
 
113
  return CodeChunk(
114
  type=ChunkType.FUNCTION,
@@ -117,7 +143,7 @@ class CodeParser:
117
  start_line=node.start_point[0],
118
  end_line=node.end_point[0],
119
  file_path=str(file_path),
120
- docstring=self._extract_docstring(node, source_code)
121
  )
122
 
123
  def parse_repository(self) -> List[CodeChunk]:
 
14
  class ChunkType(str, Enum):
15
  CLASS = "class"
16
  FUNCTION = "function"
 
17
  OTHER = "other"
18
 
19
  class CodeChunk(BaseModel):
 
43
  self.language = Language(tree_sitter_python.language())
44
  self.parser = Parser(self.language)
45
 
46
+ def _get_preceding_docstring(self, node: Node, source_code: bytes) -> Optional[str]:
47
+ """Extract docstring from comments"""
48
+ docstring_parts = []
49
+ current_node : Node = node.prev_sibling
50
+
51
+ while current_node:
52
+ print(current_node.text)
53
+ if current_node.type == "comment":
54
+ comment = source_code[current_node.start_byte:current_node.end_byte].decode('utf-8')
55
+ docstring_parts.insert(0, comment)
56
+ elif current_node.type == "expression_statement":
57
+ string_node = current_node.children[0] if current_node.children else None
58
+ if string_node and string_node.type in ("string", "string_literal"):
59
+ docstring = source_code[string_node.start_byte:string_node.end_byte].decode('utf-8')
60
+ docstring_parts.insert(0, docstring)
61
+
62
+ break
63
+ elif current_node.type not in ("empty_statement", "newline"):
64
+ break
65
+ current_node = current_node.prev_sibling
66
+
67
+ return '\n'.join(docstring_parts) if docstring_parts else None
68
+
69
+ def _has_syntax_error(self, node: Node) -> bool:
70
+ """Check if the node or its children contain syntax errors"""
71
+ if node.type == "ERROR":
72
+ return True
73
+ if node.has_error:
74
+ return True
75
+ return any(self._has_syntax_error(child) for child in node.children)
76
 
77
  def parse_file(self, file_path: Path) -> List[CodeChunk]:
78
  """Parse a single file and return list of code chunks"""
 
85
  source_code = f.read()
86
 
87
  tree = self.parser.parse(source_code)
88
+
89
+ # Check for overall syntax validity
90
+ if self._has_syntax_error(tree.root_node):
91
+ LOG.warning(f"Syntax errors found in {file_path}")
92
+ return []
93
+
94
  chunks: List[CodeChunk] = []
95
 
96
  # Process the syntax tree
 
100
  elif node.type == "function_definition":
101
  chunks.append(self._process_function(node, source_code, file_path))
102
  else:
103
+ # Skip other node types for now
104
+ pass
 
 
 
 
 
 
 
105
 
106
  return chunks
107
  except Exception as e:
 
114
  for child in node.children
115
  if child.type == "identifier")
116
 
117
+ if not name:
118
+ raise ValueError(f"Could not find class name in node: {node.text}")
119
+
120
  return CodeChunk(
121
  type=ChunkType.CLASS,
122
  name=name,
 
124
  start_line=node.start_point[0],
125
  end_line=node.end_point[0],
126
  file_path=str(file_path),
127
+ docstring=self._get_preceding_docstring(node, source_code)
128
  )
129
 
130
  def _process_function(self, node: Node, source_code: bytes, file_path: Path) -> CodeChunk:
 
132
  name = next(child.text.decode('utf-8')
133
  for child in node.children
134
  if child.type == "identifier")
135
+
136
+ if not name:
137
+ raise ValueError(f"Could not find function name in node: {node.text}")
138
 
139
  return CodeChunk(
140
  type=ChunkType.FUNCTION,
 
143
  start_line=node.start_point[0],
144
  end_line=node.end_point[0],
145
  file_path=str(file_path),
146
+ docstring=self._get_preceding_docstring(node, source_code)
147
  )
148
 
149
  def parse_repository(self) -> List[CodeChunk]:
tests/test_constants.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, NamedTuple
2
+
3
+ class ExpectedChunk(NamedTuple):
4
+ name: str
5
+ docstring: str
6
+ content_snippet: str # A unique part of the content that should be present
7
+
8
+ # Test file contents
9
+ SIMPLE_FUCNTION_CLEANED_DOCSTRING = "Say hello to the world"
10
+ SIMPLE_FUNCTION_DOCSTRING = f'\"\"\"{SIMPLE_FUCNTION_CLEANED_DOCSTRING}\"\"\"'
11
+ SIMPLE_FUNCTION = f'''
12
+ {SIMPLE_FUNCTION_DOCSTRING}
13
+ def hello_world():
14
+ return "Hello, World!"
15
+ '''
16
+
17
+ SIMPLE_CLASS_DOCSTRING = f'\"\"\"A simple class for testing\"\"\"'
18
+ SIMPLE_CLASS = f'''
19
+ {SIMPLE_CLASS_DOCSTRING}
20
+ class SimpleClass:
21
+ def __init__(self):
22
+ self.value = 42
23
+
24
+ def get_value(self):
25
+ return self.value
26
+ '''
27
+
28
+ NESTED_OUTER_CLASS_DOCSTRING = "#Outer class docstring"
29
+ NESTED_CLASS = f'''
30
+ {NESTED_OUTER_CLASS_DOCSTRING}
31
+ class OuterClass:
32
+ class InnerClass:
33
+ """Inner class docstring"""
34
+ def inner_method(self):
35
+ return "inner"
36
+
37
+ def outer_method(self):
38
+ return "outer"
39
+ '''
40
+
41
+
42
+ COMPLEX_FUNCTION_DOCSTRING = f"""\"\"\"
43
+ A complex function with type hints and docstring
44
+ Args:
45
+ param1: First parameter
46
+ param2: Optional second parameter
47
+ Returns:
48
+ List of strings
49
+ \"\"\""""
50
+
51
+ COMPLEX_CLASS_DOCSTRING = "# # #Complex class implementation"
52
+ COMPLEX_FILE = f'''
53
+ import os
54
+ from typing import List, Optional
55
+
56
+ {COMPLEX_FUNCTION_DOCSTRING}
57
+ def complex_function(param1: str, param2: Optional[int] = None) -> List[str]:
58
+ results = []
59
+ if param2 is not None:
60
+ results.extend([param1] * param2)
61
+ return results
62
+
63
+ # Some comment
64
+ CONSTANT = 42
65
+
66
+ {COMPLEX_CLASS_DOCSTRING}
67
+ class ComplexClass:
68
+ """Complex class implementation Test Test"""
69
+ def __init__(self):
70
+ self._value = None
71
+ '''
72
+
73
+ INVALID_SYNTAX = '''def invalid_syntax(:'''
74
+
75
+ # Expected test results
76
+ SIMPLE_FILE_EXPECTATIONS = {
77
+ 'hello_world': ExpectedChunk(
78
+ name="hello_world",
79
+ docstring=SIMPLE_FUNCTION_DOCSTRING,
80
+ content_snippet='return "Hello, World!"'
81
+ ),
82
+ 'SimpleClass': ExpectedChunk(
83
+ name="SimpleClass",
84
+ docstring=SIMPLE_CLASS_DOCSTRING,
85
+ content_snippet='self.value = 42'
86
+ )
87
+ }
88
+
89
+ NESTED_CLASS_EXPECTATIONS = {
90
+ 'OuterClass': ExpectedChunk(
91
+ name="OuterClass",
92
+ docstring=NESTED_OUTER_CLASS_DOCSTRING,
93
+ content_snippet='class InnerClass'
94
+ )
95
+ }
96
+
97
+ COMPLEX_FILE_EXPECTATIONS = {
98
+ 'complex_function': ExpectedChunk(
99
+ name="complex_function",
100
+ docstring=COMPLEX_FUNCTION_DOCSTRING,
101
+ content_snippet='List[str]'
102
+ ),
103
+ 'ComplexClass': ExpectedChunk(
104
+ name="ComplexClass",
105
+ docstring=COMPLEX_CLASS_DOCSTRING,
106
+ content_snippet='_value = None'
107
+ )
108
+ }
109
+
110
+ TEST_FILES = {
111
+ 'simple.py': SIMPLE_FUNCTION + SIMPLE_CLASS,
112
+ 'nested.py': NESTED_CLASS,
113
+ 'complex.py': COMPLEX_FILE
114
+ }
tests/test_parser.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from know_lang_bot.code_parser.parser import CodeChunk, CodeParser, ChunkType
2
+ from pathlib import Path
3
+ from tests.test_constants import (
4
+ SIMPLE_FILE_EXPECTATIONS,
5
+ NESTED_CLASS_EXPECTATIONS,
6
+ COMPLEX_FILE_EXPECTATIONS,
7
+ INVALID_SYNTAX,
8
+ TEST_FILES,
9
+ )
10
+ import pytest
11
+ import tempfile
12
+ import git
13
+
14
+
15
+ @pytest.fixture
16
+ def temp_repo():
17
+ """Create a temporary git repository with sample Python files"""
18
+ with tempfile.TemporaryDirectory() as temp_dir:
19
+ # Initialize git repo
20
+ repo = git.Repo.init(temp_dir)
21
+
22
+ # Create sample Python files
23
+ for filename, content in TEST_FILES.items():
24
+ file_path = Path(temp_dir) / filename
25
+ file_path.write_text(content)
26
+ repo.index.add([str(file_path)])
27
+
28
+ repo.index.commit("Initial commit")
29
+
30
+ yield temp_dir
31
+
32
+ def find_chunk_by_criteria(chunks: list[CodeChunk], **criteria) -> CodeChunk:
33
+ """Helper function to find a chunk matching given criteria"""
34
+ for chunk in chunks:
35
+ if all(getattr(chunk, k) == v for k, v in criteria.items()):
36
+ return chunk
37
+ return None
38
+
39
+ def test_init_parser(temp_repo):
40
+ """Test parser initialization"""
41
+ parser = CodeParser(temp_repo)
42
+ assert parser.repo_path == Path(temp_repo)
43
+ assert parser.language is not None
44
+ assert parser.parser is not None
45
+
46
+ def test_parse_simple_file(temp_repo):
47
+ """Test parsing a simple Python file with function and class"""
48
+ parser = CodeParser(temp_repo)
49
+ chunks = parser.parse_file(Path(temp_repo) / "simple.py")
50
+
51
+ # Test function
52
+ function_chunk = find_chunk_by_criteria(chunks, type=ChunkType.FUNCTION, name="hello_world")
53
+ assert function_chunk is not None
54
+ expected = SIMPLE_FILE_EXPECTATIONS['hello_world']
55
+ assert expected.content_snippet in function_chunk.content
56
+ assert function_chunk.docstring is not None
57
+ assert function_chunk.docstring in expected.docstring
58
+
59
+ # Test class
60
+ class_chunk = find_chunk_by_criteria(chunks, type=ChunkType.CLASS, name="SimpleClass")
61
+ assert class_chunk is not None
62
+ expected = SIMPLE_FILE_EXPECTATIONS['SimpleClass']
63
+ assert expected.content_snippet in class_chunk.content
64
+ assert class_chunk.docstring is not None
65
+ assert class_chunk.docstring in expected.docstring
66
+
67
+
68
+ def test_parse_nested_classes(temp_repo):
69
+ """Test parsing nested class definitions"""
70
+ parser = CodeParser(temp_repo)
71
+ chunks = parser.parse_file(Path(temp_repo) / "nested.py")
72
+
73
+ # Test outer class
74
+ outer_class = find_chunk_by_criteria(chunks, type=ChunkType.CLASS, name="OuterClass")
75
+ assert outer_class is not None
76
+ expected = NESTED_CLASS_EXPECTATIONS['OuterClass']
77
+ assert expected.content_snippet in outer_class.content
78
+ assert outer_class.docstring is not None
79
+ assert outer_class.docstring in expected.docstring
80
+
81
+ # Verify inner class: Not implemented yet
82
+ pass
83
+
84
+ def test_parse_complex_file(temp_repo):
85
+ """Test parsing a complex Python file"""
86
+ parser = CodeParser(temp_repo)
87
+ chunks = parser.parse_file(Path(temp_repo) / "complex.py")
88
+
89
+ # Test function with type hints
90
+ complex_func = find_chunk_by_criteria(
91
+ chunks,
92
+ type=ChunkType.FUNCTION,
93
+ name="complex_function"
94
+ )
95
+ assert complex_func is not None
96
+ expected = COMPLEX_FILE_EXPECTATIONS['complex_function']
97
+ assert expected.content_snippet in complex_func.content
98
+ assert complex_func.docstring is not None
99
+ assert complex_func.docstring in expected.docstring
100
+
101
+ # Test complex class
102
+ complex_class = find_chunk_by_criteria(
103
+ chunks,
104
+ type=ChunkType.CLASS,
105
+ name="ComplexClass"
106
+ )
107
+ assert complex_class is not None
108
+ expected = COMPLEX_FILE_EXPECTATIONS['ComplexClass']
109
+ assert expected.content_snippet in complex_class.content
110
+ assert complex_class.docstring is not None
111
+ assert complex_class.docstring in expected.docstring
112
+
113
+
114
+ def test_parse_repository(temp_repo):
115
+ """Test parsing entire repository"""
116
+ parser = CodeParser(temp_repo)
117
+ chunks = parser.parse_repository()
118
+
119
+ file_paths = {chunk.file_path for chunk in chunks}
120
+ assert len(file_paths) == 3
121
+
122
+ # Verify we can find chunks from each test file
123
+ for filename in TEST_FILES.keys():
124
+ file_chunks = [c for c in chunks if Path(c.file_path).name == filename]
125
+ assert len(file_chunks) > 0
126
+
127
+ def test_error_handling(temp_repo):
128
+ """Test error handling for invalid files"""
129
+ parser = CodeParser(temp_repo)
130
+
131
+ # Test invalid syntax
132
+ invalid_file = Path(temp_repo) / "invalid.py"
133
+ invalid_file.write_text(INVALID_SYNTAX)
134
+ chunks = parser.parse_file(invalid_file)
135
+ assert chunks == []
136
+
137
+ # Test non-existent file
138
+ nonexistent = Path(temp_repo) / "nonexistent.py"
139
+ chunks = parser.parse_file(nonexistent)
140
+ assert chunks == []
141
+
142
+ def test_non_python_files(temp_repo):
143
+ """Test handling of non-Python files"""
144
+ parser = CodeParser(temp_repo)
145
+
146
+ # Create a non-Python file
147
+ non_python = Path(temp_repo) / "readme.md"
148
+ non_python.write_text("# README")
149
+
150
+ # Should skip non-Python files
151
+ chunks = parser.parse_file(non_python)
152
+ assert chunks == []