English SDK for Apache Spark

SparkAI

Source code in pyspark_ai/pyspark_ai.py
class SparkAI:
    _HTTP_HEADER = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
        " (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "en-US,en;q=0.5",
    }

    def __init__(
        self,
        llm: Optional[BaseLanguageModel] = None,
        web_search_tool: Optional[Callable[[str], str]] = None,
        spark_session: Optional[SparkSession] = None,
        enable_cache: bool = True,
        cache_file_format: str = "json",
        cache_file_location: Optional[str] = None,
        vector_store_dir: Optional[str] = None,
        vector_store_max_gb: Optional[float] = 16,
        max_tokens_of_web_content: int = 3000,
        sample_rows_in_table_info: int = 3,
        verbose: bool = True,
    ) -> None:
        """
        Initialize the SparkAI object with the provided parameters.

        :param llm: LLM instance for selecting web search result
                                 and writing the ingestion SQL query.
        :param web_search_tool: optional function to perform web search,
                                Google search will be used if not provided
        :param spark_session: optional SparkSession, a new one will be created if not provided
        :param enable_cache: optional boolean, whether to enable caching of results
        :param cache_file_format: optional str, format for cache file if enabled
        :param vector_store_dir: optional str, directory path for vector similarity search files,
                                if storing to disk is desired
        :param vector_store_max_gb: optional float, max size of vector store dir in GB
        :param max_tokens_of_web_content: maximum tokens of web content after encoding
        :param sample_rows_in_table_info: number of rows to be sampled and shown in the table info.
                                        This is only used for SQL transform. To disable it, set it to 0.
        :param verbose: whether to print out the log
        """
        self._spark = spark_session or SparkSession.builder.getOrCreate()
        if llm is None:
            llm = ChatOpenAI(model_name="gpt-4", temperature=0)
        self._llm = llm
        self._web_search_tool = web_search_tool or self._default_web_search_tool
        # Set the flag unconditionally so that _get_tags works when caching is disabled.
        self._enable_cache = enable_cache
        if enable_cache:
            if cache_file_location is not None:
                # if there is parameter setting for it, use the parameter
                self._cache_file_location = cache_file_location
            elif "AI_CACHE_FILE_LOCATION" in os.environ:
                # otherwise read from env variable AI_CACHE_FILE_LOCATION
                self._cache_file_location = os.environ["AI_CACHE_FILE_LOCATION"]
            else:
                # use default value "spark_ai_cache.json"
                self._cache_file_location = "spark_ai_cache.json"
            self._cache = Cache(
                cache_file_location=self._cache_file_location,
                file_format=cache_file_format,
            )
            self._web_search_tool = SearchToolWithCache(
                self._web_search_tool, self._cache
            ).search
        else:
            self._cache = None
        self._vector_store_dir = vector_store_dir
        self._vector_store_max_gb = vector_store_max_gb
        self._max_tokens_of_web_content = max_tokens_of_web_content
        self._search_llm_chain = self._create_llm_chain(prompt=SEARCH_PROMPT)
        self._ingestion_chain = self._create_llm_chain(prompt=SQL_PROMPT)
        self._explain_chain = self._create_llm_chain(prompt=EXPLAIN_DF_PROMPT)
        self._verify_chain = self._create_llm_chain(prompt=VERIFY_PROMPT)
        self._udf_chain = self._create_llm_chain(prompt=UDF_PROMPT)
        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._verbose = verbose
        if verbose:
            self._logger = CodeLogger("spark_ai")
        else:
            self._logger = None
        self._sql_agent = None
        self._sql_chain = None

    def _create_llm_chain(self, prompt: BasePromptTemplate):
        if self._cache is None:
            return LLMChain(llm=self._llm, prompt=prompt)

        return LLMChainWithCache(llm=self._llm, prompt=prompt, cache=self._cache)

    @property
    def sql_chain(self):
        if self._sql_chain is None:
            self._sql_chain = SparkSQLChain(
                prompt=SQL_CHAIN_PROMPT,
                llm=self._llm,
                logger=self._logger,
                spark=self._spark,
            )
        return self._sql_chain

    @property
    def sql_agent(self):
        if self._sql_agent is None:
            self._sql_agent = self._create_sql_agent()
        return self._sql_agent

    def _create_sql_agent(self):
        # exclude SimilarValueTool if vector_store_dir not configured
        tools = (
            [
                QuerySparkSQLTool(spark=self._spark),
                QueryValidationTool(spark=self._spark),
                SimilarValueTool(
                    spark=self._spark,
                    vector_store_dir=self._vector_store_dir,
                    lru_vector_store=LRUVectorStore(
                        self._vector_store_dir, self._vector_store_max_gb
                    ),
                ),
            ]
            if self._vector_store_dir
            else [
                QuerySparkSQLTool(spark=self._spark),
                QueryValidationTool(spark=self._spark),
            ]
        )
        agent = ReActSparkSQLAgent.from_llm_and_tools(
            llm=self._llm, tools=tools, verbose=True
        )
        return AgentExecutor.from_agent_and_tools(
            agent=agent, tools=tools, verbose=True
        )

    @staticmethod
    def _generate_search_prompt(columns: Optional[List[str]]) -> str:
        return (
            f"The best search results should contain as many as possible of these info: {','.join(columns)}"
            if columns is not None and len(columns) > 0
            else ""
        )

    @staticmethod
    def _generate_sql_prompt(columns: Optional[List[str]]) -> str:
        return (
            f"The result view MUST contain following columns: {columns}"
            if columns is not None and len(columns) > 0
            else ""
        )

    @staticmethod
    def _default_web_search_tool(desc: str) -> str:
        search_wrapper = GoogleSearchAPIWrapper()
        return str(search_wrapper.results(query=desc, num_results=10))

    @staticmethod
    def _is_http_or_https_url(s: str):
        result = urlparse(s)  # Parse the URL
        # Check if the scheme is 'http' or 'https'
        return result.scheme in ["http", "https"]

    def log(self, message: str) -> None:
        if self._verbose:
            self._logger.info(message)

    def _trim_text_from_end(self, text: str, max_tokens: int) -> str:
        """
        Trim text from the end based on the maximum number of tokens allowed.

        :param text: text to trim
        :param max_tokens: maximum tokens allowed
        :return: trimmed text
        """
        import tiktoken

        encoding = tiktoken.get_encoding("cl100k_base")
        tokens = list(encoding.encode(text))
        if len(tokens) > max_tokens:
            tokens = tokens[:max_tokens]
        return encoding.decode(tokens)

    def _get_url_from_search_tool(
        self, desc: str, columns: Optional[List[str]], cache: bool
    ) -> str:
        search_result = self._web_search_tool(desc)
        search_columns_hint = self._generate_search_prompt(columns)
        # Run the LLM chain to pick the best search result
        tags = self._get_tags(cache)
        return self._search_llm_chain.run(
            tags=tags,
            query=desc,
            search_results=search_result,
            columns={search_columns_hint},
        )

    def _create_dataframe_with_llm(
        self,
        text: str,
        desc: str,
        columns: Optional[List[str]],
        cache: bool,
    ) -> DataFrame:
        clean_text = " ".join(text.split())
        web_content = self._trim_text_from_end(
            clean_text, self._max_tokens_of_web_content
        )

        sql_columns_hint = self._generate_sql_prompt(columns)

        # Run the LLM chain to get an ingestion SQL query
        tags = self._get_tags(cache)
        temp_view_name = random_view_name(web_content)
        llm_result = self._ingestion_chain.run(
            tags=tags,
            query=desc,
            web_content=web_content,
            view_name=temp_view_name,
            columns=sql_columns_hint,
        )
        sql_query = AIUtils.extract_code_blocks(llm_result)[0]
        # The actual view name used in the SQL query may be different from the
        # temp view name because of caching.
        view_name = SparkUtils.extract_view_name(sql_query)
        formatted_sql_query = CodeLogger.colorize_code(sql_query, "sql")
        self.log(f"SQL query for the ingestion:\n{formatted_sql_query}")
        self.log(f"Storing data into temp view: {view_name}\n")
        self._spark.sql(sql_query)
        return self._spark.table(view_name)

    def _get_df_explain(self, df: DataFrame, cache: bool) -> str:
        raw_analyzed_str = SparkUtils.get_analyzed_plan_from_explain(df)
        tags = self._get_tags(cache)
        return self._explain_chain.run(
            tags=tags, input=SparkUtils.trim_hash_id(raw_analyzed_str)
        )

    def _get_tags(self, cache: bool) -> Optional[List[str]]:
        if self._enable_cache and not cache:
            return SKIP_CACHE_TAGS
        return None

    def create_df(
        self, desc: str, columns: Optional[List[str]] = None, cache: bool = True
    ) -> DataFrame:
        """
        Create a Spark DataFrame by querying an LLM from web search result.

        :param desc: the description of the result DataFrame, which will be used for
                     web searching
        :param columns: the expected column names in the result DataFrame
        :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.

        :return: a Spark DataFrame
        """
        # check for necessary dependencies
        try:
            import requests
            import tiktoken
            from bs4 import BeautifulSoup
        except ImportError:
            raise Exception(
                "Dependencies for `ingestion` not found. To fix, run `pip install pyspark-ai[ingestion]`"
            )

        url = desc.strip()  # Remove leading and trailing whitespace
        is_url = self._is_http_or_https_url(url)
        # If the input is not a valid URL, use search tool to get the dataset.
        if not is_url:
            url = self._get_url_from_search_tool(desc, columns, cache)

        self.log(f"Parsing URL: {url}\n")
        try:
            response = requests.get(url, headers=self._HTTP_HEADER)
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            self.log(f"HTTP error occurred: {http_err}")
            return
        except Exception as err:
            self.log(f"Other error occurred: {err}")
            return

        soup = BeautifulSoup(response.text, "html.parser")

        # add url and page content to cache
        if cache:
            if self._cache.lookup(key=url):
                page_content = self._cache.lookup(key=url)
            else:
                page_content = soup.get_text()
                self._cache.update(key=url, val=page_content)
        else:
            page_content = soup.get_text()

        # If the input is a URL link, use the title of web page as the
        # dataset's description.
        if is_url:
            desc = soup.title.string
        return self._create_dataframe_with_llm(page_content, desc, columns, cache)

    def _get_transform_sql_query_from_agent(
        self,
        temp_view_name: str,
        sample_vals_str: str,
        comment: str,
        desc: str,
    ) -> str:
        llm_result = self.sql_agent.run(
            view_name=temp_view_name,
            sample_vals=sample_vals_str,
            comment=comment,
            desc=desc,
        )
        sql_query_from_response = AIUtils.extract_code_blocks(llm_result)[0]
        return sql_query_from_response

    def _get_sql_query(
        self,
        temp_view_name: str,
        sample_vals_str: str,
        comment: str,
        desc: str,
    ) -> str:
        # If LLM is ChatOpenAI instance and the model is GPT-4, use the ReActSparkSQLAgent to generate the SQL query
        if isinstance(self._llm, ChatOpenAI) and self._llm.model_name == "gpt-4":
            return self._get_transform_sql_query_from_agent(
                temp_view_name, sample_vals_str, comment, desc
            )
        else:
            # Otherwise, generate the SQL query with a prompt with few-shot examples
            return self.sql_chain.run(
                view_name=temp_view_name,
                sample_vals=sample_vals_str,
                comment=comment,
                desc=desc,
            )

    def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str:
        temp_view_name = random_view_name(df)
        create_temp_view_code = CodeLogger.colorize_code(
            f'df.createOrReplaceTempView("{temp_view_name}")', "python"
        )
        self.log(f"Creating temp view for the transform:\n{create_temp_view_code}")
        df.createOrReplaceTempView(temp_view_name)
        schema_lst = SparkUtils.get_df_schema(df)
        schema_str = "\n".join(schema_lst)
        sample_rows = SparkUtils.get_sample_spark_rows(
            df, self._sample_rows_in_table_info
        )
        schema_row_lst = []
        for index in range(len(schema_lst)):
            sample_vals = []
            for sample_row in sample_rows:
                sample_vals.append(sample_row[index])
            curr_schema_row = f"({schema_lst[index]}, {str(sample_vals)})"
            schema_row_lst.append(curr_schema_row)
        sample_vals_str = "\n".join([str(val) for val in schema_row_lst])
        comment = SparkUtils.get_table_comment(df, self._spark)

        if cache:
            cache_key = ReActSparkSQLAgent.cache_key(desc, schema_str)
            cached_result = self._cache.lookup(key=cache_key)
            if cached_result is not None:
                self.log("Using cached result for the transform:")
                self.log(CodeLogger.colorize_code(cached_result, "sql"))
                return replace_view_name(cached_result, temp_view_name)
            else:
                sql_query = self._get_sql_query(
                    temp_view_name, sample_vals_str, comment, desc
                )
                self._cache.update(key=cache_key, val=canonize_string(sql_query))
                return sql_query
        else:
            return self._get_sql_query(temp_view_name, sample_vals_str, comment, desc)

    def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
        """
        This method applies a transformation to a provided Spark DataFrame,
        the specifics of which are determined by the 'desc' parameter.

        :param df: The Spark DataFrame that is to be transformed.
        :param desc: A natural language string that outlines the specific transformation to be applied on the DataFrame.
        :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.

        :return: Returns a new Spark DataFrame that is the result of applying the specified transformation
                 on the input DataFrame.
        """
        sql_query = self._get_transform_sql_query(df, desc, cache)
        return self._spark.sql(sql_query)

    def explain_df(self, df: DataFrame, cache: bool = True) -> str:
        """
        This method generates a natural language explanation of the SQL plan of the input Spark DataFrame.

        :param df: The Spark DataFrame to be explained.
        :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.

        :return: A string explanation of the DataFrame's SQL plan, detailing what the DataFrame is intended to retrieve.
        """
        explain_result = self._get_df_explain(df, cache)
        # If there is code block in the explain result, ignore it.
        if "```" in explain_result:
            summary = explain_result.split("```")[-1]
            return summary.strip()
        else:
            return explain_result

    def plot_df(
        self, df: DataFrame, desc: Optional[str] = None, cache: bool = True
    ) -> str:
        """
        Plot a Spark DataFrame, the specifics of which are determined by the `desc` parameter.
        If `desc` is not provided, the method will try to plot the DataFrame based on its schema.

        :param df: The PySpark dataframe to generate plotting code for.
        :param desc: An optional natural language string that outlines the specific plot to be
                     produced from the DataFrame.
        :param cache: If `True`, uses cached results, if available. If `False`, retrieves fresh results and updates the cache.

        :return: Returns the generated code as a string. If the generated code is not valid Python code, an empty string
                 is returned.
        """
        # check for necessary plot dependencies
        try:
            import pandas
            import plotly
            import pyarrow
        except ImportError:
            raise Exception(
                "Dependencies for `plot_df` not found. To fix, run `pip install pyspark-ai[plot]`"
            )
        instruction = f"The purpose of the plot: {desc}" if desc is not None else ""
        tags = self._get_tags(cache)
        plot_chain = PythonExecutor(
            df=DataFrameLike(df),
            prompt=PLOT_PROMPT,
            cache=self._cache,
            llm=self._llm,
            logger=self._logger,
        )
        return plot_chain.run(
            tags=tags,
            columns=SparkUtils.get_df_schema(df),
            instruction=instruction,
        )

    def verify_df(self, df: DataFrame, desc: str, cache: bool = True) -> bool:
        """
        This method creates and runs test cases for the provided PySpark dataframe transformation function.

        :param df: The Spark DataFrame to be verified
        :param desc: A description of the expectation to be verified
        :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.
        :return: True if the transformation is valid and False otherwise.
        :rtype: bool
        """
        tags = self._get_tags(cache)
        llm_output = self._verify_chain.run(tags=tags, df=df, desc=desc)

        codeblocks = AIUtils.extract_code_blocks(llm_output)
        llm_output = "\n".join(codeblocks)

        self.log(f"LLM Output:\n{llm_output}")

        formatted_code = CodeLogger.colorize_code(llm_output, "python")
        self.log(f"Generated code:\n{formatted_code}")

        locals_ = {}
        try:
            exec(compile(llm_output, "verify_df-CodeGen", "exec"), {"df": df}, locals_)
        except Exception as e:
            self.log("Could not evaluate Python code")
            self.log(str(e))
            return False
        self.log(f"\nResult: {locals_.get('result')}")

        return locals_.get("result", False)

    def udf(self, func: Callable) -> Callable:
        from inspect import signature

        desc = func.__doc__
        func_signature = str(signature(func))
        input_args_types = func_signature.split("->")[0].strip()
        return_type = func_signature.split("->")[1].strip()
        udf_name = func.__name__

        code = self._udf_chain.run(
            input_args_types=input_args_types,
            desc=desc,
            return_type=return_type,
            udf_name=udf_name,
        )

        formatted_code = CodeLogger.colorize_code(code, "python")
        self.log(f"Creating following Python UDF:\n{formatted_code}")

        locals_ = {}
        try:
            exec(compile(code, "udf-CodeGen", "exec"), globals(), locals_)
        except Exception as e:
            raise Exception("Could not evaluate Python code", e)
        return locals_[udf_name]

    def activate(self):
        """
        Activates AI utility functions for Spark DataFrame.
        """
        DataFrame.ai = AIUtils(self)
        # Patch the Spark Connect DataFrame as well.
        try:
            from pyspark.sql.connect.dataframe import DataFrame as CDataFrame

            CDataFrame.ai = AIUtils(self)
        except ImportError:
            self.log(
                "The pyspark.sql.connect.dataframe module could not be imported. "
                "This might be due to your PySpark version being below 3.4."
            )

    def commit(self):
        """
        Commit the staging in-memory cache into persistent cache, if cache is enabled.
        """
        if self._cache is not None:
            self._cache.commit()

__init__(llm=None, web_search_tool=None, spark_session=None, enable_cache=True, cache_file_format='json', cache_file_location=None, vector_store_dir=None, vector_store_max_gb=16, max_tokens_of_web_content=3000, sample_rows_in_table_info=3, verbose=True)

Initialize the SparkAI object with the provided parameters.

Parameters:

    llm (Optional[BaseLanguageModel], default None):
        LLM instance for selecting the web search result and writing the ingestion SQL query.
    web_search_tool (Optional[Callable[[str], str]], default None):
        Optional function to perform web search; Google search is used if not provided.
    spark_session (Optional[SparkSession], default None):
        Optional SparkSession; a new one is created if not provided.
    enable_cache (bool, default True):
        Whether to enable caching of results.
    cache_file_format (str, default 'json'):
        Format of the cache file, if caching is enabled.
    cache_file_location (Optional[str], default None):
        Path of the cache file; if not set, the AI_CACHE_FILE_LOCATION environment variable is used, falling back to "spark_ai_cache.json".
    vector_store_dir (Optional[str], default None):
        Directory path for vector similarity search files, if storing to disk is desired.
    vector_store_max_gb (Optional[float], default 16):
        Maximum size of the vector store directory in GB.
    max_tokens_of_web_content (int, default 3000):
        Maximum number of tokens of web content after encoding.
    sample_rows_in_table_info (int, default 3):
        Number of rows to be sampled and shown in the table info. This is only used for SQL transform. To disable it, set it to 0.
    verbose (bool, default True):
        Whether to print out the log.

Source code in pyspark_ai/pyspark_ai.py
def __init__(
    self,
    llm: Optional[BaseLanguageModel] = None,
    web_search_tool: Optional[Callable[[str], str]] = None,
    spark_session: Optional[SparkSession] = None,
    enable_cache: bool = True,
    cache_file_format: str = "json",
    cache_file_location: Optional[str] = None,
    vector_store_dir: Optional[str] = None,
    vector_store_max_gb: Optional[float] = 16,
    max_tokens_of_web_content: int = 3000,
    sample_rows_in_table_info: int = 3,
    verbose: bool = True,
) -> None:
    """
    Initialize the SparkAI object with the provided parameters.

    :param llm: LLM instance for selecting web search result
                             and writing the ingestion SQL query.
    :param web_search_tool: optional function to perform web search,
                            Google search will be used if not provided
    :param spark_session: optional SparkSession, a new one will be created if not provided
    :param enable_cache: optional boolean, whether to enable caching of results
    :param cache_file_format: optional str, format for cache file if enabled
    :param vector_store_dir: optional str, directory path for vector similarity search files,
                            if storing to disk is desired
    :param vector_store_max_gb: optional float, max size of vector store dir in GB
    :param max_tokens_of_web_content: maximum tokens of web content after encoding
    :param sample_rows_in_table_info: number of rows to be sampled and shown in the table info.
                                    This is only used for SQL transform. To disable it, set it to 0.
    :param verbose: whether to print out the log
    """
    self._spark = spark_session or SparkSession.builder.getOrCreate()
    if llm is None:
        llm = ChatOpenAI(model_name="gpt-4", temperature=0)
    self._llm = llm
    self._web_search_tool = web_search_tool or self._default_web_search_tool
    # Set the flag unconditionally so that _get_tags works when caching is disabled.
    self._enable_cache = enable_cache
    if enable_cache:
        if cache_file_location is not None:
            # if there is parameter setting for it, use the parameter
            self._cache_file_location = cache_file_location
        elif "AI_CACHE_FILE_LOCATION" in os.environ:
            # otherwise read from env variable AI_CACHE_FILE_LOCATION
            self._cache_file_location = os.environ["AI_CACHE_FILE_LOCATION"]
        else:
            # use default value "spark_ai_cache.json"
            self._cache_file_location = "spark_ai_cache.json"
        self._cache = Cache(
            cache_file_location=self._cache_file_location,
            file_format=cache_file_format,
        )
        self._web_search_tool = SearchToolWithCache(
            self._web_search_tool, self._cache
        ).search
    else:
        self._cache = None
    self._vector_store_dir = vector_store_dir
    self._vector_store_max_gb = vector_store_max_gb
    self._max_tokens_of_web_content = max_tokens_of_web_content
    self._search_llm_chain = self._create_llm_chain(prompt=SEARCH_PROMPT)
    self._ingestion_chain = self._create_llm_chain(prompt=SQL_PROMPT)
    self._explain_chain = self._create_llm_chain(prompt=EXPLAIN_DF_PROMPT)
    self._verify_chain = self._create_llm_chain(prompt=VERIFY_PROMPT)
    self._udf_chain = self._create_llm_chain(prompt=UDF_PROMPT)
    self._sample_rows_in_table_info = sample_rows_in_table_info
    self._verbose = verbose
    if verbose:
        self._logger = CodeLogger("spark_ai")
    else:
        self._logger = None
    self._sql_agent = None
    self._sql_chain = None
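
For illustration, a hedged sketch of a non-default construction (the model name, file paths, and langchain import path are assumptions and may vary by version):

from langchain.chat_models import ChatOpenAI
from pyspark_ai import SparkAI

# Supply an explicit LLM and cache settings instead of the defaults.
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
spark_ai = SparkAI(
    llm=llm,
    cache_file_location="/tmp/spark_ai_cache.json",   # overrides AI_CACHE_FILE_LOCATION
    vector_store_dir="/tmp/spark_ai_vector_store",    # enables SimilarValueTool in the SQL agent
    vector_store_max_gb=4,
    verbose=True,
)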

activate()

Activates AI utility functions for Spark DataFrame.

Source code in pyspark_ai/pyspark_ai.py
def activate(self):
    """
    Activates AI utility functions for Spark DataFrame.
    """
    DataFrame.ai = AIUtils(self)
    # Patch the Spark Connect DataFrame as well.
    try:
        from pyspark.sql.connect.dataframe import DataFrame as CDataFrame

        CDataFrame.ai = AIUtils(self)
    except ImportError:
        self.log(
            "The pyspark.sql.connect.dataframe module could not be imported. "
            "This might be due to your PySpark version being below 3.4."
        )
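
A sketch of what activation enables (assuming the AIUtils accessor mirrors the transform_df/explain_df methods; the DataFrame contents are made up):

from pyspark.sql import SparkSession
from pyspark_ai import SparkAI

spark = SparkSession.builder.getOrCreate()
spark_ai = SparkAI(spark_session=spark)
spark_ai.activate()  # patches DataFrame (and Spark Connect DataFrame) with .ai

df = spark.createDataFrame(
    [("US", 331), ("CA", 38), ("MX", 126)], ["country", "population_millions"]
)
# The .ai accessor routes back to this SparkAI instance.
df.ai.transform("countries with population above 100 million").show()
print(df.ai.explain())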

commit()

Commit the staging in-memory cache into persistent cache, if cache is enabled.

Source code in pyspark_ai/pyspark_ai.py
def commit(self):
    """
    Commit the staging in-memory cache into persistent cache, if cache is enabled.
    """
    if self._cache is not None:
        self._cache.commit()
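
A sketch of the intended call pattern (illustrative; assumes caching is enabled, as it is by default):

from pyspark_ai import SparkAI

spark_ai = SparkAI(enable_cache=True)

# ... calls such as create_df / transform_df accumulate entries in the
# in-memory staging cache ...

# Persist the staged entries to the cache file (spark_ai_cache.json unless
# overridden) so that later sessions can reuse them.
spark_ai.commit()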

create_df(desc, columns=None, cache=True)

Create a Spark DataFrame by using an LLM to ingest data from a web search result or a given URL.

Parameters:

    desc (str, required):
        The description of the result DataFrame, which will be used for web searching.
    columns (Optional[List[str]], default None):
        The expected column names in the result DataFrame.
    cache (bool, default True):
        If True, fetches cached data, if available. If False, retrieves fresh data and updates cache.

Returns:

    DataFrame: a Spark DataFrame.

Source code in pyspark_ai/pyspark_ai.py
def create_df(
    self, desc: str, columns: Optional[List[str]] = None, cache: bool = True
) -> DataFrame:
    """
    Create a Spark DataFrame by querying an LLM from web search result.

    :param desc: the description of the result DataFrame, which will be used for
                 web searching
    :param columns: the expected column names in the result DataFrame
    :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.

    :return: a Spark DataFrame
    """
    # check for necessary dependencies
    try:
        import requests
        import tiktoken
        from bs4 import BeautifulSoup
    except ImportError:
        raise Exception(
            "Dependencies for `ingestion` not found. To fix, run `pip install pyspark-ai[ingestion]`"
        )

    url = desc.strip()  # Remove leading and trailing whitespace
    is_url = self._is_http_or_https_url(url)
    # If the input is not a valid URL, use search tool to get the dataset.
    if not is_url:
        url = self._get_url_from_search_tool(desc, columns, cache)

    self.log(f"Parsing URL: {url}\n")
    try:
        response = requests.get(url, headers=self._HTTP_HEADER)
        response.raise_for_status()
    except requests.exceptions.HTTPError as http_err:
        self.log(f"HTTP error occurred: {http_err}")
        return
    except Exception as err:
        self.log(f"Other error occurred: {err}")
        return

    soup = BeautifulSoup(response.text, "html.parser")

    # add url and page content to cache
    if cache:
        if self._cache.lookup(key=url):
            page_content = self._cache.lookup(key=url)
        else:
            page_content = soup.get_text()
            self._cache.update(key=url, val=page_content)
    else:
        page_content = soup.get_text()

    # If the input is a URL link, use the title of web page as the
    # dataset's description.
    if is_url:
        desc = soup.title.string
    return self._create_dataframe_with_llm(page_content, desc, columns, cache)
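
A hedged example of both input forms (the URL and the search description are placeholders, and the resulting schema depends on what the LLM ingests):

from pyspark_ai import SparkAI

spark_ai = SparkAI()

# From a URL: the page text is fetched and an ingestion SQL query creates a
# temp view that backs the returned DataFrame.
gdp_df = spark_ai.create_df("https://en.wikipedia.org/wiki/List_of_countries_by_GDP_(nominal)")

# From a description: a web search selects the source page; expected column
# names can be passed to steer the generated query.
sales_df = spark_ai.create_df(
    "2022 USA national auto sales by brand",
    columns=["brand", "us_sales_2022"],
)
sales_df.show()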

explain_df(df, cache=True)

This method generates a natural language explanation of the SQL plan of the input Spark DataFrame.

Parameters:

    df (DataFrame, required):
        The Spark DataFrame to be explained.
    cache (bool, default True):
        If True, fetches cached data, if available. If False, retrieves fresh data and updates cache.

Returns:

    str: A string explanation of the DataFrame's SQL plan, detailing what the DataFrame is intended to retrieve.

Source code in pyspark_ai/pyspark_ai.py
def explain_df(self, df: DataFrame, cache: bool = True) -> str:
    """
    This method generates a natural language explanation of the SQL plan of the input Spark DataFrame.

    :param df: The Spark DataFrame to be explained.
    :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.

    :return: A string explanation of the DataFrame's SQL plan, detailing what the DataFrame is intended to retrieve.
    """
    explain_result = self._get_df_explain(df, cache)
    # If there is code block in the explain result, ignore it.
    if "```" in explain_result:
        summary = explain_result.split("```")[-1]
        return summary.strip()
    else:
        return explain_result
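
A short sketch (the sample data is invented; the returned wording depends on the LLM):

from pyspark.sql import SparkSession
from pyspark_ai import SparkAI

spark = SparkSession.builder.getOrCreate()
spark_ai = SparkAI(spark_session=spark)

df = spark.createDataFrame(
    [("apple", 3), ("banana", 5), ("apple", 2)], ["fruit", "quantity"]
)
grouped = df.groupBy("fruit").sum("quantity")

# Summarizes the analyzed plan of `grouped` in natural language, e.g. that it
# aggregates quantities per fruit.
print(spark_ai.explain_df(grouped))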

plot_df(df, desc=None, cache=True)

Plot a Spark DataFrame, the specifics of which are determined by the desc parameter. If desc is not provided, the method will try to plot the DataFrame based on its schema.

Parameters:

    df (DataFrame, required):
        The PySpark DataFrame to generate plotting code for.
    desc (Optional[str], default None):
        An optional natural language string that outlines the specific plot to be produced from the DataFrame.
    cache (bool, default True):
        If True, uses cached results, if available. If False, retrieves fresh results and updates the cache.

Returns:

    str: The generated code as a string. If the generated code is not valid Python code, an empty string is returned.

Source code in pyspark_ai/pyspark_ai.py
def plot_df(
    self, df: DataFrame, desc: Optional[str] = None, cache: bool = True
) -> str:
    """
    Plot a Spark DataFrame, the specifics of which are determined by the `desc` parameter.
    If `desc` is not provided, the method will try to plot the DataFrame based on its schema.

    :param df: The PySpark dataframe to generate plotting code for.
    :param desc: An optional natural language string that outlines the specific plot to be
                 produced from the DataFrame.
    :param cache: If `True`, uses cached results, if available. If `False`, retrieves fresh results and updates the cache.

    :return: Returns the generated code as a string. If the generated code is not valid Python code, an empty string
             is returned.
    """
    # check for necessary plot dependencies
    try:
        import pandas
        import plotly
        import pyarrow
    except ImportError:
        raise Exception(
            "Dependencies for `plot_df` not found. To fix, run `pip install pyspark-ai[plot]`"
        )
    instruction = f"The purpose of the plot: {desc}" if desc is not None else ""
    tags = self._get_tags(cache)
    plot_chain = PythonExecutor(
        df=DataFrameLike(df),
        prompt=PLOT_PROMPT,
        cache=self._cache,
        llm=self._llm,
        logger=self._logger,
    )
    return plot_chain.run(
        tags=tags,
        columns=SparkUtils.get_df_schema(df),
        instruction=instruction,
    )
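
A sketch of typical usage (requires the plot extras; the data and description are illustrative):

from pyspark.sql import SparkSession
from pyspark_ai import SparkAI

spark = SparkSession.builder.getOrCreate()
spark_ai = SparkAI(spark_session=spark)

revenue_df = spark.createDataFrame(
    [(2021, 120.0), (2022, 150.0), (2023, 180.0)], ["year", "revenue"]
)

# With no desc, the LLM chooses a chart from the schema; with a desc, it
# follows the instruction. The generated Plotly code is executed and returned.
spark_ai.plot_df(revenue_df, desc="bar chart of revenue by year")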

transform_df(df, desc, cache=True)

This method applies a transformation to a provided Spark DataFrame, the specifics of which are determined by the 'desc' parameter.

Parameters:

    df (DataFrame, required):
        The Spark DataFrame that is to be transformed.
    desc (str, required):
        A natural language string that outlines the specific transformation to be applied on the DataFrame.
    cache (bool, default True):
        If True, fetches cached data, if available. If False, retrieves fresh data and updates cache.

Returns:

    DataFrame: A new Spark DataFrame that is the result of applying the specified transformation on the input DataFrame.

Source code in pyspark_ai/pyspark_ai.py
def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
    """
    This method applies a transformation to a provided Spark DataFrame,
    the specifics of which are determined by the 'desc' parameter.

    :param df: The Spark DataFrame that is to be transformed.
    :param desc: A natural language string that outlines the specific transformation to be applied on the DataFrame.
    :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.

    :return: Returns a new Spark DataFrame that is the result of applying the specified transformation
             on the input DataFrame.
    """
    sql_query = self._get_transform_sql_query(df, desc, cache)
    return self._spark.sql(sql_query)
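
A minimal sketch (invented data; the generated SQL runs against a temp view of the input DataFrame):

from pyspark.sql import SparkSession
from pyspark_ai import SparkAI

spark = SparkSession.builder.getOrCreate()
spark_ai = SparkAI(spark_session=spark)

orders = spark.createDataFrame(
    [("US", 120.0), ("US", 80.0), ("CA", 60.0)], ["country", "amount"]
)

# With a GPT-4 ChatOpenAI llm the ReAct SQL agent writes the query; otherwise
# the few-shot SQL chain is used.
totals = spark_ai.transform_df(orders, "total amount per country, highest first")
totals.show()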

verify_df(df, desc, cache=True)

This method creates and runs test cases for the provided PySpark dataframe transformation function.

Parameters:

    df (DataFrame, required):
        The Spark DataFrame to be verified.
    desc (str, required):
        A description of the expectation to be verified.
    cache (bool, default True):
        If True, fetches cached data, if available. If False, retrieves fresh data and updates cache.

Returns:

    bool: True if the transformation is valid and False otherwise.

Source code in pyspark_ai/pyspark_ai.py
def verify_df(self, df: DataFrame, desc: str, cache: bool = True) -> bool:
    """
    This method creates and runs test cases for the provided PySpark dataframe transformation function.

    :param df: The Spark DataFrame to be verified
    :param desc: A description of the expectation to be verified
    :param cache: If `True`, fetches cached data, if available. If `False`, retrieves fresh data and updates cache.
    :return: True if the transformation is valid and False otherwise.
    :rtype: bool
    """
    tags = self._get_tags(cache)
    llm_output = self._verify_chain.run(tags=tags, df=df, desc=desc)

    codeblocks = AIUtils.extract_code_blocks(llm_output)
    llm_output = "\n".join(codeblocks)

    self.log(f"LLM Output:\n{llm_output}")

    formatted_code = CodeLogger.colorize_code(llm_output, "python")
    self.log(f"Generated code:\n{formatted_code}")

    locals_ = {}
    try:
        exec(compile(llm_output, "verify_df-CodeGen", "exec"), {"df": df}, locals_)
    except Exception as e:
        self.log("Could not evaluate Python code")
        self.log(str(e))
        return False
    self.log(f"\nResult: {locals_.get('result')}")

    return locals_.get("result", False)
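
A brief sketch (invented data; the LLM-generated test assigns a boolean `result`, which this method returns, or False if the code fails to execute):

from pyspark.sql import SparkSession
from pyspark_ai import SparkAI

spark = SparkSession.builder.getOrCreate()
spark_ai = SparkAI(spark_session=spark)

users = spark.createDataFrame(
    [(1, "alice@example.com"), (2, "bob@example.com")], ["id", "email"]
)

# Returns True only if the generated test sets result to True.
ok = spark_ai.verify_df(users, "expect the id column to contain no null values")
print(ok)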