gabykim's picture
refactor package name to knowlang
60532a1
raw
history blame
6.07 kB
from typing import List, Optional
from pathlib import Path
import tree_sitter_python
from tree_sitter import Language, Parser, Node
from knowlang.parser.base.parser import LanguageParser
from knowlang.core.types import CodeChunk, ChunkType
from knowlang.utils.fancy_log import FancyLogger
# Module-level logger; PythonParser.parse_file uses it to report skipped
# files (debug/warning) and parse failures (error).
LOG = FancyLogger(__name__)
class PythonParser(LanguageParser):
    """Python-specific implementation of LanguageParser.

    Parses a Python source file with tree-sitter and produces one
    ``CodeChunk`` for every top-level ``class_definition`` and
    ``function_definition`` node found directly under the module root.
    """

    def setup(self) -> None:
        """Initialize tree-sitter with Python language support.

        Creates the tree-sitter ``Language`` and ``Parser`` objects and caches
        the ``"python"`` entry of ``self.config.parser.languages`` as
        ``self.language_config`` (read later for ``max_file_size`` and
        ``file_extensions``).
        """
        self.language = Language(tree_sitter_python.language())
        self.parser = Parser(self.language)
        self.language_config = self.config.parser.languages["python"]

    @staticmethod
    def _node_source(node: Node, source_code: bytes) -> str:
        """Return the UTF-8 decoded slice of ``source_code`` covered by ``node``."""
        return source_code[node.start_byte:node.end_byte].decode('utf-8')

    @staticmethod
    def _identifier_text(node: Node) -> Optional[str]:
        """Return the text of the first direct ``identifier`` child, or None."""
        for child in node.children:
            if child.type == "identifier":
                return child.text.decode('utf-8')
        return None

    @staticmethod
    def _find_enclosing_class(node: Node) -> Optional[Node]:
        """Return the ``class_definition`` that directly contains ``node``, if any.

        In the tree-sitter Python grammar a method is not a direct child of
        its class: it sits inside the class body ``block`` and may also be
        wrapped in a ``decorated_definition``. Those wrapper nodes are skipped
        so that the class itself is found. Any other ancestor (for example an
        enclosing function) stops the search and yields None.
        """
        ancestor = node.parent
        while ancestor is not None and ancestor.type in ("block", "decorated_definition"):
            ancestor = ancestor.parent
        if ancestor is not None and ancestor.type == "class_definition":
            return ancestor
        return None

    def _get_preceding_docstring(self, node: Node, source_code: bytes) -> Optional[str]:
        """Extract documentation text from the siblings that precede ``node``.

        Walks backwards over ``node``'s previous siblings and collects:

        * every ``comment`` node, and
        * the leading string of the first ``expression_statement`` reached
          (the walk stops at that statement whether or not it is a string).

        ``empty_statement`` / ``newline`` nodes are skipped; any other node
        type ends the walk.

        NOTE(review): only *preceding* siblings are inspected. A PEP 257
        docstring written as the first statement inside the definition body
        is not read here -- confirm whether that is intended.

        Args:
            node: The definition node whose preceding documentation is wanted.
            source_code: Raw bytes of the file ``node`` was parsed from.

        Returns:
            The collected pieces in source order joined with newlines, or
            None when nothing was collected.
        """
        # Collected nearest-first, reversed once at the end; this avoids the
        # repeated O(n) ``list.insert(0, ...)`` calls.
        collected: List[str] = []
        sibling = node.prev_sibling

        while sibling:
            if sibling.type == "comment":
                collected.append(self._node_source(sibling, source_code))
            elif sibling.type == "expression_statement":
                first_child = sibling.children[0] if sibling.children else None
                if first_child and first_child.type in ("string", "string_literal"):
                    collected.append(self._node_source(first_child, source_code))
                # An expression statement always terminates the walk so that
                # comments above unrelated code are never attributed to node.
                break
            elif sibling.type not in ("empty_statement", "newline"):
                break
            sibling = sibling.prev_sibling

        if not collected:
            return None
        return '\n'.join(reversed(collected))

    def _has_syntax_error(self, node: Node) -> bool:
        """Check if the node or any of its descendants contains a syntax error.

        tree-sitter's ``Node.has_error`` already reports errors anywhere in
        the subtree, so no manual recursion over the children is needed.
        """
        return node.type == "ERROR" or node.has_error

    def _process_class(self, node: Node, source_code: bytes, file_path: Path) -> CodeChunk:
        """Process a class node and return a CodeChunk.

        Args:
            node: A ``class_definition`` node.
            source_code: Raw bytes of the file being parsed.
            file_path: Path of the file being parsed (stored as a string).

        Returns:
            A ``CodeChunk`` of type ``ChunkType.CLASS``. ``start_line`` and
            ``end_line`` are taken directly from tree-sitter's
            ``start_point`` / ``end_point`` row values.

        Raises:
            ValueError: If the node has no ``identifier`` child to use as name.
        """
        name = self._identifier_text(node)
        if not name:
            raise ValueError(f"Could not find class name in node: {node.text}")

        return CodeChunk(
            type=ChunkType.CLASS,
            name=name,
            content=self._node_source(node, source_code),
            start_line=node.start_point[0],
            end_line=node.end_point[0],
            file_path=str(file_path),
            docstring=self._get_preceding_docstring(node, source_code)
        )

    def _process_function(self, node: Node, source_code: bytes, file_path: Path) -> CodeChunk:
        """Process a function node and return a CodeChunk.

        Args:
            node: A ``function_definition`` node.
            source_code: Raw bytes of the file being parsed.
            file_path: Path of the file being parsed (stored as a string).

        Returns:
            A ``CodeChunk`` of type ``ChunkType.FUNCTION``. ``parent_name`` is
            the name of the class that directly contains the function, or
            None for a function that is not a method.

        Raises:
            ValueError: If the node has no ``identifier`` child to use as name.
        """
        name = self._identifier_text(node)
        if not name:
            raise ValueError(f"Could not find function name in node: {node.text}")

        # Determine if this is a method within a class. The class is never the
        # direct parent of a method (the body ``block`` sits in between), so
        # the lookup has to skip the wrapper nodes.
        enclosing_class = self._find_enclosing_class(node)
        parent_name = None
        if enclosing_class is not None:
            parent_name = self._identifier_text(enclosing_class)

        return CodeChunk(
            type=ChunkType.FUNCTION,
            name=name,
            content=self._node_source(node, source_code),
            start_line=node.start_point[0],
            end_line=node.end_point[0],
            file_path=str(file_path),
            parent_name=parent_name,
            docstring=self._get_preceding_docstring(node, source_code)
        )

    def parse_file(self, file_path: Path) -> List[CodeChunk]:
        """Parse a single Python file and return list of code chunks.

        Only nodes that are direct children of the module root are turned
        into chunks; every other node type is skipped for now.

        Args:
            file_path: Path of the file to parse.

        Returns:
            One ``CodeChunk`` per top-level class or function. An empty list
            is returned when the extension is unsupported, the file exceeds
            ``max_file_size``, the tree contains syntax errors, or any
            exception is raised while reading or parsing (the exception is
            logged, not propagated).
        """
        if not self.supports_extension(file_path.suffix):
            LOG.debug(f"Skipping file {file_path}: unsupported extension")
            return []

        # Deliberate best-effort boundary: a single bad file is logged and
        # yields no chunks instead of aborting the caller.
        try:
            # Check file size limit
            size_limit = self.language_config.max_file_size
            if file_path.stat().st_size > size_limit:
                LOG.warning(f"Skipping file {file_path}: exceeds size limit of {size_limit} bytes")
                return []

            with open(file_path, 'rb') as f:
                source_code = f.read()

            if not self.parser:
                raise RuntimeError("Parser not initialized. Call setup() first.")

            tree = self.parser.parse(source_code)

            # Check for overall syntax validity
            if self._has_syntax_error(tree.root_node):
                LOG.warning(f"Syntax errors found in {file_path}")
                return []

            # Map each supported top-level node type to its processor.
            handlers = {
                "class_definition": self._process_class,
                "function_definition": self._process_function,
            }

            chunks: List[CodeChunk] = []
            for node in tree.root_node.children:
                handler = handlers.get(node.type)
                if handler is not None:
                    chunks.append(handler(node, source_code, file_path))
            return chunks

        except Exception as e:
            LOG.error(f"Error parsing file {file_path}: {e!s}")
            return []

    def supports_extension(self, ext: str) -> bool:
        """Check if this parser supports a given file extension.

        Args:
            ext: File suffix as produced by ``Path.suffix``.

        Returns:
            True when ``ext`` is listed in ``self.language_config.file_extensions``.
        """
        return ext in self.language_config.file_extensions