Spaces:
Sleeping
Sleeping
File size: 6,067 Bytes
71fa0c7 60532a1 71fa0c7 c20abe2 71fa0c7 2182a08 71fa0c7 c20abe2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from typing import List, Optional
from pathlib import Path
import tree_sitter_python
from tree_sitter import Language, Parser, Node
from knowlang.parser.base.parser import LanguageParser
from knowlang.core.types import CodeChunk, ChunkType
from knowlang.utils.fancy_log import FancyLogger
LOG = FancyLogger(__name__)
class PythonParser(LanguageParser):
    """Python-specific implementation of LanguageParser.

    Uses tree-sitter with the tree-sitter-python grammar to split a .py
    file into CodeChunk objects (one per top-level class / function).
    """

    def setup(self) -> None:
        """Initialize tree-sitter with Python language support.

        Must be called before parse_file(); binds the Python grammar,
        a parser instance, and the per-language config section.
        """
        self.language = Language(tree_sitter_python.language())
        self.parser = Parser(self.language)
        self.language_config = self.config.parser.languages["python"]

    def _get_identifier_name(self, node: Node) -> Optional[str]:
        """Return the text of the first 'identifier' child of *node*, or None.

        Shared by class and function processing (the original duplicated
        this generator expression in three places).
        """
        return next(
            (child.text.decode('utf-8')
             for child in node.children
             if child.type == "identifier"),
            None
        )

    def _get_preceding_docstring(self, node: Node, source_code: bytes) -> Optional[str]:
        """Collect documentation text that directly precedes *node*.

        Walks backwards over previous siblings: consecutive ``comment``
        nodes are accumulated (in source order); a preceding string
        expression statement is included and ends the walk; any other
        non-trivial statement ends the walk without being included.

        Returns the parts joined with newlines, or None if nothing found.
        """
        docstring_parts: List[str] = []
        current_node = node.prev_sibling
        while current_node:
            if current_node.type == "comment":
                comment = source_code[current_node.start_byte:current_node.end_byte].decode('utf-8')
                # insert(0, ...) keeps parts in original source order while
                # we walk backwards through the siblings.
                docstring_parts.insert(0, comment)
            elif current_node.type == "expression_statement":
                string_node = current_node.children[0] if current_node.children else None
                if string_node and string_node.type in ("string", "string_literal"):
                    docstring = source_code[string_node.start_byte:string_node.end_byte].decode('utf-8')
                    docstring_parts.insert(0, docstring)
                break
            elif current_node.type not in ("empty_statement", "newline"):
                # Any other statement breaks the contiguous doc block.
                break
            current_node = current_node.prev_sibling
        return '\n'.join(docstring_parts) if docstring_parts else None

    def _has_syntax_error(self, node: Node) -> bool:
        """Check whether *node*'s subtree contains syntax errors.

        tree-sitter's ``Node.has_error`` is already true when the node or
        any descendant is an ERROR node, so the original's manual
        recursion over all children was redundant work.
        """
        return node.type == "ERROR" or node.has_error

    def _process_class(self, node: Node, source_code: bytes, file_path: Path) -> CodeChunk:
        """Process a class_definition node and return a CodeChunk.

        Raises:
            ValueError: if the node has no identifier child (no class name).
        """
        name = self._get_identifier_name(node)
        if not name:
            raise ValueError(f"Could not find class name in node: {node.text}")
        return CodeChunk(
            type=ChunkType.CLASS,
            name=name,
            content=source_code[node.start_byte:node.end_byte].decode('utf-8'),
            # tree-sitter points are 0-based (row, column); kept as-is for
            # compatibility with existing callers.
            start_line=node.start_point[0],
            end_line=node.end_point[0],
            file_path=str(file_path),
            docstring=self._get_preceding_docstring(node, source_code)
        )

    def _process_function(self, node: Node, source_code: bytes, file_path: Path) -> CodeChunk:
        """Process a function_definition node and return a CodeChunk.

        Sets parent_name when the function is a method directly inside a
        class_definition.

        Raises:
            ValueError: if the node has no identifier child (no function name).
        """
        name = self._get_identifier_name(node)
        if not name:
            raise ValueError(f"Could not find function name in node: {node.text}")
        # Determine if this is a method within a class
        parent_node = node.parent
        parent_name = None
        if parent_node and parent_node.type == "class_definition":
            parent_name = self._get_identifier_name(parent_node)
        return CodeChunk(
            type=ChunkType.FUNCTION,
            name=name,
            content=source_code[node.start_byte:node.end_byte].decode('utf-8'),
            start_line=node.start_point[0],
            end_line=node.end_point[0],
            file_path=str(file_path),
            parent_name=parent_name,
            docstring=self._get_preceding_docstring(node, source_code)
        )

    def parse_file(self, file_path: Path) -> List[CodeChunk]:
        """Parse a single Python file and return list of code chunks.

        Returns an empty list (after logging) when the extension is
        unsupported, the file exceeds the configured size limit, the file
        has syntax errors, or any unexpected error occurs.
        """
        if not self.supports_extension(file_path.suffix):
            LOG.debug(f"Skipping file {file_path}: unsupported extension")
            return []
        try:
            # Check file size limit
            if file_path.stat().st_size > self.language_config.max_file_size:
                LOG.warning(f"Skipping file {file_path}: exceeds size limit of {self.language_config.max_file_size} bytes")
                return []
            with open(file_path, 'rb') as f:
                source_code = f.read()
            if not self.parser:
                raise RuntimeError("Parser not initialized. Call setup() first.")
            tree = self.parser.parse(source_code)
            # Check for overall syntax validity
            if self._has_syntax_error(tree.root_node):
                LOG.warning(f"Syntax errors found in {file_path}")
                return []
            chunks: List[CodeChunk] = []
            # Process the syntax tree
            for node in tree.root_node.children:
                # Decorated defs are wrapped in a 'decorated_definition'
                # node by tree-sitter-python; unwrap so decorated classes
                # and functions are not silently skipped (they were
                # dropped entirely before this fix).
                target = node
                if node.type == "decorated_definition":
                    target = next(
                        (child for child in node.children
                         if child.type in ("class_definition", "function_definition")),
                        node
                    )
                if target.type == "class_definition":
                    chunks.append(self._process_class(target, source_code, file_path))
                elif target.type == "function_definition":
                    chunks.append(self._process_function(target, source_code, file_path))
                else:
                    # Skip other node types (imports, module statements) for now
                    pass
            return chunks
        except Exception as e:
            # Boundary-level catch: parsing must never crash the indexing
            # run; log and treat the file as yielding no chunks.
            LOG.error(f"Error parsing file {file_path}: {str(e)}")
            return []

    def supports_extension(self, ext: str) -> bool:
        """Check if this parser supports a given file extension"""
        return ext in self.language_config.file_extensions