Files
CljElixir/lib/clj_elixir/repl.ex
2026-03-09 23:09:46 -04:00

257 lines
8.0 KiB
Elixir

defmodule CljElixir.REPL do
  @moduledoc """
  CljElixir Read-Eval-Print Loop engine.

  Maintains state across evaluations: bindings persist, and modules defined
  in one evaluation are available in the next.

  Tracks the current namespace (`ns`) so that bare `defn`/`def` forms are
  merged into the active module and the module is recompiled incrementally.
  Accumulated definitions are replayed in the order in which they were first
  entered.
  """

  defstruct bindings: [],
            history: [],
            counter: 1,
            env: nil,
            current_ns: nil,
            module_defs: %{},
            def_order: []

  @typedoc """
  REPL session state threaded through `eval/2`.

  `module_defs` maps a definition key to the most recent form for that key and
  `def_order` lists those keys in first-definition order.
  """
  @type t :: %__MODULE__{
          bindings: list(),
          history: [String.t()],
          counter: pos_integer(),
          env: term(),
          current_ns: String.t() | nil,
          module_defs: map(),
          def_order: list()
        }

  # Heads of top-level forms that are accumulated into the active namespace.
  @def_heads ~w(defn defn- def defprotocol defrecord extend-type
                extend-protocol reify defmacro use m/=>)

  # Heads that bind a plain name: a later form using any of these heads with the
  # same name replaces the earlier one.
  @name_binding_heads ~w(def defn defn- defmacro)

  @doc """
  Create a new REPL state.

  Also sets the global compiler option `ignore_module_conflict: true`.
  """
  @spec new() :: t()
  def new do
    Code.compiler_options(ignore_module_conflict: true)
    %__MODULE__{}
  end

  @doc "Return the current namespace name (defaults to \"user\")."
  @spec current_ns(t()) :: String.t()
  def current_ns(%__MODULE__{current_ns: ns}), do: ns || "user"

  @doc """
  Evaluate a CljElixir source string in the given REPL state.

  Returns `{:ok, result_string, new_state}` or `{:error, error_string, new_state}`.
  The evaluation counter is incremented in both cases; bindings and history
  are only updated on success.
  """
  @spec eval(String.t(), t()) :: {:ok, String.t(), t()} | {:error, String.t(), t()}
  def eval(source, state) do
    case CljElixir.Reader.read_string(source) do
      {:ok, forms} ->
        cond do
          Enum.any?(forms, &ns_form?/1) ->
            eval_with_ns(forms, source, state)

          state.current_ns != nil and Enum.any?(forms, &def_form?/1) ->
            eval_in_ns(forms, source, state)

          true ->
            eval_plain(source, state)
        end

      {:error, reason} ->
        error_msg = if is_binary(reason), do: reason, else: inspect(reason)
        fail(state, "Read error: #{error_msg}")
    end
  end

  @doc """
  Check whether `input` is ready to be handed to the reader.

  Returns `false` while parens, brackets or braces are still open, or while a
  string literal is unterminated. Delimiters inside string literals and inside
  `;` line comments are ignored.

  Returns `true` as soon as an excess closing delimiter is seen: such input can
  never become balanced by appending more text, so the reader should report it.
  """
  @spec balanced?(String.t()) :: boolean()
  def balanced?(input) when is_binary(input) do
    # Scan bytes rather than graphemes: every character of interest is ASCII,
    # and "\r\n" is a single grapheme, which would hide the newline that ends a
    # line comment.
    scan_delimiters(input, :code, 0, 0, 0)
  end

  # ---------------------------------------------------------------------------
  # Eval strategies
  # ---------------------------------------------------------------------------

  # Full ns block: set namespace, capture defs, compile normally.
  defp eval_with_ns(forms, source, state) do
    ns_name = extract_ns_name(forms)
    {defs, order} = forms |> Enum.filter(&def_form?/1) |> merge_defs(%{}, [])

    case CljElixir.Compiler.eval_string(source, eval_opts(state)) do
      {:ok, result, new_bindings} ->
        new_state = %{
          commit(state, source, new_bindings)
          | current_ns: ns_name,
            module_defs: defs,
            def_order: order
        }

        {:ok, CljElixir.Printer.pr_str(result), new_state}

      {:error, errors} ->
        fail(state, format_errors(errors))
    end
  end

  # Bare defs in active namespace: merge into module_defs and recompile module.
  defp eval_in_ns(forms, source, state) do
    {new_def_forms, exprs} = Enum.split_with(forms, &def_form?/1)

    {merged_defs, merged_order} =
      merge_defs(new_def_forms, state.module_defs, state.def_order)

    # Reconstruct: ns + all accumulated defs (in definition order) + current expressions.
    all_forms =
      [make_ns_form(state.current_ns) | ordered_defs(merged_defs, merged_order)] ++ exprs

    case CljElixir.Compiler.eval_forms(all_forms, eval_opts(state)) do
      {:ok, result, new_bindings} ->
        result_str =
          if exprs == [] do
            # Def-only: show var-like representation.
            Enum.map_join(new_def_forms, " ", fn form ->
              "#'#{state.current_ns}/#{extract_def_name(form)}"
            end)
          else
            CljElixir.Printer.pr_str(result)
          end

        new_state = %{
          commit(state, source, new_bindings)
          | module_defs: merged_defs,
            def_order: merged_order
        }

        {:ok, result_str, new_state}

      {:error, errors} ->
        fail(state, format_errors(errors))
    end
  end

  # No ns context: eval as-is (legacy / ad-hoc expressions).
  defp eval_plain(source, state) do
    case CljElixir.Compiler.eval_string(source, eval_opts(state)) do
      {:ok, result, new_bindings} ->
        {:ok, CljElixir.Printer.pr_str(result), commit(state, source, new_bindings)}

      {:error, errors} ->
        fail(state, format_errors(errors))
    end
  end

  # Options passed to the compiler for every evaluation.
  defp eval_opts(state), do: [bindings: state.bindings, file: "repl"]

  # State after a successful evaluation: new bindings, source pushed onto the
  # history, counter advanced.
  defp commit(state, source, new_bindings) do
    %{
      state
      | bindings: new_bindings,
        history: [source | state.history],
        counter: state.counter + 1
    }
  end

  # Error result: only the counter advances; bindings and history are untouched.
  defp fail(state, message) do
    {:error, message, %{state | counter: state.counter + 1}}
  end

  # ---------------------------------------------------------------------------
  # Form classification helpers
  # ---------------------------------------------------------------------------

  defp ns_form?({:list, _, [{:symbol, _, "ns"} | _]}), do: true
  defp ns_form?(_form), do: false

  defp def_form?({:list, _, [{:symbol, _, head} | _]}) when head in @def_heads, do: true
  defp def_form?(_form), do: false

  # Name of the first `(ns name ...)` form, or nil when none has a symbol name.
  defp extract_ns_name(forms) do
    Enum.find_value(forms, fn
      {:list, _, [{:symbol, _, "ns"}, {:symbol, _, name} | _]} -> name
      _ -> nil
    end)
  end

  # Folds `def_forms` into `defs`, returning the updated map together with
  # `order` extended by the keys that were not present before. A form whose key
  # already exists replaces the stored form and keeps its original position.
  defp merge_defs(def_forms, defs, order) do
    {merged, new_keys_reversed} =
      Enum.reduce(def_forms, {defs, []}, fn form, {acc, new_keys} ->
        key = def_key(form)
        new_keys = if Map.has_key?(acc, key), do: new_keys, else: [key | new_keys]
        {Map.put(acc, key, form), new_keys}
      end)

    {merged, order ++ Enum.reverse(new_keys_reversed)}
  end

  # Stored forms in definition order. Entries of `defs` that `order` does not
  # mention (e.g. a state whose `def_order` was never populated) are appended so
  # that no definition is silently dropped.
  defp ordered_defs(defs, order) do
    tracked = for key <- order, Map.has_key?(defs, key), do: Map.fetch!(defs, key)
    untracked = defs |> Map.drop(order) |> Map.values()
    tracked ++ untracked
  end

  # Key under which a definition is stored. Name-binding heads share the bare
  # name so that redefinition replaces the previous form; every other head is
  # keyed together with the head so that, for example, `(m/=> foo ...)` does not
  # overwrite `(defn foo ...)`.
  defp def_key({:list, _, [{:symbol, _, head}, {:symbol, _, name} | _]})
       when head in @name_binding_heads,
       do: name

  defp def_key({:list, _, [{:symbol, _, head}, {:symbol, _, name} | _]}), do: {head, name}
  defp def_key(form), do: extract_def_name(form)

  # Display name of a definition; forms without a symbol in second position get
  # a hash-derived placeholder.
  defp extract_def_name({:list, _, [{:symbol, _, _}, {:symbol, _, name} | _]}), do: name
  defp extract_def_name(form), do: "anon_#{:erlang.phash2(form)}"

  # Synthetic `(ns name)` form used to re-open the active namespace.
  defp make_ns_form(ns_name) do
    meta = %{line: 0, col: 0}
    {:list, meta, [{:symbol, meta, "ns"}, {:symbol, meta, ns_name}]}
  end

  # ---------------------------------------------------------------------------
  # Delimiter balancing
  # ---------------------------------------------------------------------------

  # scan_delimiters(input, mode, parens, brackets, braces)
  #
  # `mode` is :code outside string literals, :string inside one, and :escape
  # right after a backslash inside a string.

  # Excess closing delimiter: let the reader report the error.
  defp scan_delimiters(_input, _mode, parens, brackets, braces)
       when parens < 0 or brackets < 0 or braces < 0,
       do: true

  # End of input: complete only when nothing is open and no string is pending.
  defp scan_delimiters(<<>>, mode, parens, brackets, braces) do
    mode == :code and parens == 0 and brackets == 0 and braces == 0
  end

  # Inside a string literal.
  defp scan_delimiters(<<_, rest::binary>>, :escape, p, b, br),
    do: scan_delimiters(rest, :string, p, b, br)

  defp scan_delimiters(<<?\\, rest::binary>>, :string, p, b, br),
    do: scan_delimiters(rest, :escape, p, b, br)

  defp scan_delimiters(<<?\", rest::binary>>, :string, p, b, br),
    do: scan_delimiters(rest, :code, p, b, br)

  defp scan_delimiters(<<_, rest::binary>>, :string, p, b, br),
    do: scan_delimiters(rest, :string, p, b, br)

  # Outside string literals.
  defp scan_delimiters(<<?\", rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :string, p, b, br)

  defp scan_delimiters(<<?;, rest::binary>>, :code, p, b, br),
    do: rest |> skip_to_newline() |> scan_delimiters(:code, p, b, br)

  defp scan_delimiters(<<?(, rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p + 1, b, br)

  defp scan_delimiters(<<?), rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p - 1, b, br)

  defp scan_delimiters(<<?[, rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p, b + 1, br)

  defp scan_delimiters(<<?], rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p, b - 1, br)

  defp scan_delimiters(<<?{, rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p, b, br + 1)

  defp scan_delimiters(<<?}, rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p, b, br - 1)

  defp scan_delimiters(<<_, rest::binary>>, :code, p, b, br),
    do: scan_delimiters(rest, :code, p, b, br)

  # Drops the remainder of a line comment, stopping at (and keeping) the newline.
  defp skip_to_newline(<<?\n, _::binary>> = rest), do: rest
  defp skip_to_newline(<<_, rest::binary>>), do: skip_to_newline(rest)
  defp skip_to_newline(<<>>), do: <<>>

  # ---------------------------------------------------------------------------
  # Error formatting
  # ---------------------------------------------------------------------------

  # Renders compiler errors one per line; entries carrying a positive integer
  # `:line` are prefixed with it.
  defp format_errors(errors) when is_list(errors) do
    Enum.map_join(errors, "\n", fn
      %{message: msg, line: line} when is_integer(line) and line > 0 ->
        "Error on line #{line}: #{msg}"

      %{message: msg} ->
        "Error: #{msg}"

      other ->
        "Error: #{inspect(other)}"
    end)
  end

  defp format_errors(other) do
    "Error: #{inspect(other)}"
  end
end