root/tags/release-1.0-beta/tsp.c

Revision 39, 11.5 KB (checked in by anton, 3 years ago)

1.0.0b tag added

Line 
1/*
2 * Traveling Salesman Problem solution algorithm for PostgreSQL
3 *
4 * Copyright (c) 2006 Anton A. Patrushev, Orkney, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "tsp.h"
23
24#include "postgres.h"
25#include "executor/spi.h"
26#include "funcapi.h"
27
28#include "string.h"
29#include "math.h"
30
#ifdef PG_MODULE_MAGIC
PG_MODULE_MAGIC;
#endif


// ------------------------------------------------------------------------

/*
 * Define PROFILE to have profiling enabled.
 *
 * profstart(tv)      records the current time in the struct timeval `tv`.
 * profstop(name, tv) emits a NOTICE with the microseconds elapsed since
 *                    profstart(tv).
 *
 * Both macros expand to a single do/while(0) statement WITHOUT a trailing
 * semicolon, so they behave like function calls -- including as the body of
 * an unbraced if/else.
 */
//#define PROFILE

#ifdef PROFILE
#include <sys/time.h>

struct timeval prof_tsp, prof_store, prof_extract, prof_total;
long proftime[5];
long profipts1, profipts2, profopts;

#define profstart(x) do { gettimeofday(&(x), NULL); } while (0)

/*
 * Seconds and microseconds are subtracted separately so the intermediate
 * values stay small even where long is only 32 bits wide.
 */
#define profstop(n, x) do { struct timeval _profstop;                   \
        long _proftime;                                                 \
        gettimeofday(&_profstop, NULL);                                 \
        _proftime = (long) (_profstop.tv_sec - (x).tv_sec) * 1000000L   \
                  + (long) (_profstop.tv_usec - (x).tv_usec);           \
        elog(NOTICE,                                                    \
             "PRF(%s) %ld (%f ms)",                                     \
             (n),                                                       \
             _proftime, _proftime / 1000.0);                            \
        } while (0)

#else

#define profstart(x) do { } while (0)
#define profstop(n, x) do { } while (0)

#endif // PROFILE
67
68// ------------------------------------------------------------------------
69
70Datum tsp(PG_FUNCTION_ARGS);
71
/*
 * Debug tracing.  Uncomment the DEBUG definition below (it must come after
 * the #undef) to route DBG() messages to elog(NOTICE); otherwise DBG()
 * expands to an empty statement and its arguments are not evaluated.
 *
 * Uses standard C99 __VA_ARGS__ rather than the GNU-only "arg..." / "## arg"
 * form.  The format string is part of the variadic list, so DBG("text")
 * with no further arguments is valid standard C.
 */
#undef DEBUG
//#define DEBUG 1

#ifdef DEBUG
#define DBG(...) elog(NOTICE, __VA_ARGS__)
#else
#define DBG(...) do { } while (0)
#endif

// The number of tuples to fetch from the SPI cursor at each iteration
#define TUPLIMIT 1000
84
85// Apologies for using fixed-length arrays.  But this is an example, not
86// production code ;)
87//#define MAX_TOWNS 40
88
89float DISTANCE[MAX_TOWNS][MAX_TOWNS];
90float x[MAX_TOWNS],y[MAX_TOWNS];
91int total_tuples;
92
93
94static char *
95text2char(text *in)
96{
97  char *out = (char*)palloc(VARSIZE(in));
98
99  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
100  out[VARSIZE(in) - VARHDRSZ] = '\0';
101  return out;
102}
103
104static int
105finish(int code, int ret)
106{
107  code = SPI_finish();
108  if (code  != SPI_OK_FINISH )
109  {
110    elog(ERROR,"couldn't disconnect from SPI");
111    return -1 ;
112  }
113  return ret;
114}
115                 
116
/*
 * Attribute numbers, as returned by SPI_fnumber, of the columns the point
 * query must provide.  These are column positions, not coordinates, so they
 * are stored as int: the type SPI_fnumber returns and SPI_getbinval takes.
 * -1 in id marks "not resolved yet".
 */
typedef struct point_columns
{
  int id;   /* attnum of "source_id" */
  int x;    /* attnum of "x" */
  int y;    /* attnum of "y" */
} point_columns_t;
123
124
125static int
126fetch_point_columns(SPITupleTable *tuptable, point_columns_t *point_columns)
127{
128  DBG("Fetching point");
129
130  point_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "source_id");
131  point_columns->x = SPI_fnumber(SPI_tuptable->tupdesc, "x");
132  point_columns->y = SPI_fnumber(SPI_tuptable->tupdesc, "y");
133  if (point_columns->id == SPI_ERROR_NOATTRIBUTE ||
134      point_columns->x == SPI_ERROR_NOATTRIBUTE ||
135      point_columns->y == SPI_ERROR_NOATTRIBUTE)
136    {
137      elog(ERROR, "Error, query must return columns "
138           "'source_id', 'x' and 'y'");
139      return -1;
140    }
141   
142  DBG("* Point %i [%f, %f]", point_columns->id, point_columns->x,
143      point_columns->y);
144
145  return 0;
146}
147
148static void
149fetch_point(HeapTuple *tuple, TupleDesc *tupdesc,
150            point_columns_t *point_columns, point_t *point)
151{
152  Datum binval;
153  bool isnull;
154
155  DBG("inside fetch_point\n");
156
157  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->id, &isnull);
158  DBG("Got id\n");
159
160  if (isnull)
161    elog(ERROR, "id contains a null value");
162
163  point->id = DatumGetInt32(binval);
164
165  DBG("id = %i\n", point->id);
166
167  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->x, &isnull);
168  if (isnull)
169    elog(ERROR, "x contains a null value");
170
171  point->x = DatumGetFloat8(binval);
172
173  DBG("x = %f\n", point->x);
174
175  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->y, &isnull);
176
177  if (isnull)
178    elog(ERROR, "y contains a null value");
179
180  point->y = DatumGetFloat8(binval);
181
182  DBG("y = %f\n", point->y);
183}
184
185
186static int solve_tsp(char* sql, char* p_ids,
187                     int source, path_element_t* path)
188{
189  int SPIcode;
190  void *SPIplan;
191  Portal SPIportal;
192  bool moredata = TRUE;
193  int ntuples;
194
195  //todo replace path (vector of path_element_t) with this array
196  int ids[MAX_TOWNS];
197
198  point_t *points=NULL;
199  point_columns_t point_columns = {id: -1, x: -1, y:-1};
200
201  char *err_msg = NULL;
202  int ret = -1;
203   
204  char *p;
205  int   z = 0;
206 
207  int    tt, cc;
208  double dx, dy;
209  float  fit=0.0;
210
211  DBG("inside tsp\n");
212
213  //int total_tuples = 0;
214  total_tuples = 0;
215
216  p = strtok(p_ids, ",");
217  while(p != NULL)
218    {
219      //      ((path_element_t*)path)[z].vertex_id = atoi(p);
220      ids[z]=atoi(p);
221      p = strtok(NULL, ",");
222      z++;
223      if(z >= MAX_TOWNS)
224      {
225        elog(ERROR, "Number of points exeeds max number.");
226        break;
227      }
228    }
229   
230  DBG("ZZZ %i\n",z);
231  DBG("start tsp\n");
232       
233  SPIcode = SPI_connect();
234
235  if (SPIcode  != SPI_OK_CONNECT)
236    {
237      elog(ERROR, "tsp: couldn't open a connection to SPI");
238      return -1;
239    }
240
241  SPIplan = SPI_prepare(sql, 0, NULL);
242
243  if (SPIplan  == NULL)
244    {
245      elog(ERROR, "tsp: couldn't create query plan via SPI");
246      return -1;
247    }
248
249  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
250    {
251      elog(ERROR, "tsp: SPI_cursor_open('%s') returns NULL", sql);
252      return -1;
253    }
254   
255  DBG("Query: %s\n",sql);
256  DBG("Query executed\n");
257
258  while (moredata == TRUE)
259    {
260      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
261
262      if (point_columns.id == -1)
263        {
264          if (fetch_point_columns(SPI_tuptable, &point_columns) == -1)
265            return finish(SPIcode, ret);
266        }
267
268      ntuples = SPI_processed;
269
270      total_tuples += ntuples;
271
272      DBG("Tuples: %i\n", total_tuples);
273
274      if (!points)
275        points = palloc(total_tuples * sizeof(point_t));
276      else
277        points = repalloc(points, total_tuples * sizeof(point_t));
278                                       
279      if (points == NULL)
280        {
281          elog(ERROR, "Out of memory");
282          return finish(SPIcode, ret);
283        }
284
285      if (ntuples > 0)
286        {
287          int t;
288          SPITupleTable *tuptable = SPI_tuptable;
289          TupleDesc tupdesc = SPI_tuptable->tupdesc;
290
291          DBG("Got tuple desc\n");
292               
293          for (t = 0; t < ntuples; t++)
294            {
295              HeapTuple tuple = tuptable->vals[t];
296              DBG("Before point fetched\n");
297              fetch_point(&tuple, &tupdesc, &point_columns,
298                          &points[total_tuples - ntuples + t]);
299              DBG("Point fetched\n");
300            }
301
302          SPI_freetuptable(tuptable);
303        }
304      else
305        {
306          moredata = FALSE;
307        }                                       
308    }
309
310 
311  DBG("Calling TSP\n");
312       
313  profstop("extract", prof_extract);
314  profstart(prof_tsp);
315
316  DBG("Total tuples: %i\n", total_tuples);
317
318  for(tt=0;tt<total_tuples;++tt)
319    {
320      //((path_element_t*)path)[tt].vertex_id = points[tt].id;
321      ids[tt] = points[tt].id;
322      x[tt] = points[tt].x;
323      y[tt] = points[tt].y;
324 
325      DBG("Point at %i: %i [%f, %f]\n",  tt, ids[tt], x[tt], y[tt]);
326           
327      // ((path_element_t*)path)[tt].vertex_id, x[tt], y[tt]);
328
329      for(cc=0;cc<total_tuples;++cc)
330        {
331          dx=x[tt]-x[cc]; dy=y[tt]-y[cc];
332          DISTANCE[tt][cc] = DISTANCE[cc][tt] = sqrt(dx*dx+dy*dy);
333        }
334    }
335
336  DBG("DISTANCE counted\n");
337  pfree(points);
338   
339  //ret = find_tsp_solution(total_tuples, DISTANCE,
340  //   path, source, &fit, &err_msg);
341
342  ret = find_tsp_solution(total_tuples, DISTANCE, ids,
343                          source, &fit, err_msg);
344
345  for(tt=0;tt<total_tuples;++tt)
346    {
347      ((path_element_t*)path)[tt].vertex_id = ids[tt];
348    }
349   
350  DBG("TSP solved!\n");
351  DBG("Score: %f\n", fit);
352
353  profstop("tsp", prof_tsp);
354  profstart(prof_store);
355
356  DBG("Profile changed and ret is %i", ret);
357
358  if (ret < 0)
359    {
360      //elog(ERROR, "Error computing path: %s", err_msg);
361      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED), errmsg("Error computing path: %s", err_msg)));
362    }
363
364  return finish(SPIcode, ret);   
365}
366
367PG_FUNCTION_INFO_V1(tsp);
368Datum
369tsp(PG_FUNCTION_ARGS)
370{
371  FuncCallContext     *funcctx;
372  int                  call_cntr;
373  int                  max_calls;
374  TupleDesc            tuple_desc;
375  path_element_t        *path;
376   
377  /* stuff done only on the first call of the function */
378  if (SRF_IS_FIRSTCALL())
379    {
380      MemoryContext   oldcontext;
381      //int path_count;
382      int ret=-1;
383
384      // XXX profiling messages are not thread safe
385      profstart(prof_total);
386      profstart(prof_extract);
387
388      /* create a function context for cross-call persistence */
389      funcctx = SRF_FIRSTCALL_INIT();
390
391      /* switch to memory context appropriate for multiple function calls */
392      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
393       
394
395      path = (path_element_t *)palloc(sizeof(path_element_t)*MAX_TOWNS);
396
397
398      ret = solve_tsp(text2char(PG_GETARG_TEXT_P(0)),
399                      text2char(PG_GETARG_TEXT_P(1)),
400                      PG_GETARG_INT32(2),
401                      path);
402
403      /* total number of tuples to be returned */
404      DBG("Counting tuples number\n");
405
406      funcctx->max_calls = total_tuples;
407
408      funcctx->user_fctx = path;
409
410      funcctx->tuple_desc = BlessTupleDesc(
411                              RelationNameGetTupleDesc("path_result"));
412      MemoryContextSwitchTo(oldcontext);
413    }
414
415  /* stuff done on every call of the function */
416  funcctx = SRF_PERCALL_SETUP();
417
418  call_cntr = funcctx->call_cntr;
419  max_calls = funcctx->max_calls;
420  tuple_desc = funcctx->tuple_desc;
421
422  path = (path_element_t *)funcctx->user_fctx;
423
424  DBG("Trying to allocate some memory\n");
425  DBG("call_cntr = %i, max_calls = %i\n", call_cntr, max_calls);
426
427  if (call_cntr < max_calls)    /* do when there is more left to send */
428    {
429      HeapTuple    tuple;
430      Datum        result;
431      Datum *values;
432      char* nulls;
433
434      DBG("Before allocation\n");
435
436      values = (Datum*)palloc(3*sizeof(Datum));
437      nulls = (char*)palloc(3*sizeof(char));
438
439      DBG("After allocation\n");
440
441      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
442      nulls[0] = ' ';
443      values[1] = Int32GetDatum(path[call_cntr].edge_id);
444      nulls[1] = ' ';
445      values[2] = Float8GetDatum(path[call_cntr].cost);
446      nulls[2] = ' ';
447
448      DBG("Heap making\n");
449
450      tuple = heap_formtuple(tuple_desc, values, nulls);
451
452      DBG("Datum making\n");
453
454      /* make the tuple into a datum */
455      result = HeapTupleGetDatum(tuple);
456
457      DBG("VAL: %i\n, %i", values[0], result);
458      DBG("Trying to free some memory\n");
459   
460      /* clean up (this is not really necessary) */
461      pfree(values);
462      pfree(nulls);
463       
464
465      SRF_RETURN_NEXT(funcctx, result);
466    }
467  else    /* do when there is no more left */
468    {
469      DBG("Ending function\n");
470      profstop("store", prof_store);
471      profstop("total", prof_total);
472      DBG("Profiles stopped\n");
473
474      pfree(path);
475
476      DBG("Path cleared\n");
477
478      SRF_RETURN_DONE(funcctx);
479    }
480}
Note: See TracBrowser for help on using the browser.